“c79f45bf1bf1bc209abfafb9b112537981b24046”上不存在“paddle/fluid/lite/utils/any.h”
提交 d9ad276c 编写于 作者: Z zlsh80826

merge develop

...@@ -63,7 +63,28 @@ if(WIN32) ...@@ -63,7 +63,28 @@ if(WIN32)
set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT") set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT")
foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO)
if(${flag_var} MATCHES "/MD")
string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
endif() endif()
endforeach(flag_var)
endif()
# windows build turn off warnings.
foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO)
string(REGEX REPLACE "/W[1-4]" " /W0 " ${flag_var} "${${flag_var}}")
endforeach(flag_var)
foreach(flag_var CMAKE_CXX_FLAGS CMAKE_C_FLAGS)
set(${flag_var} "${${flag_var}} /w")
endforeach(flag_var)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838 /MP") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838 /MP")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838 /MP") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838 /MP")
......
...@@ -16,6 +16,7 @@ else() ...@@ -16,6 +16,7 @@ else()
set(paddle_known_gpu_archs8 "30 35 50 52 60 61") set(paddle_known_gpu_archs8 "30 35 50 52 60 61")
set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70") set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75") set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75")
set(paddle_known_gpu_archs11 "52 60 61 70 75 80")
endif() endif()
###################################################################################### ######################################################################################
...@@ -106,6 +107,9 @@ function(select_nvcc_arch_flags out_variable) ...@@ -106,6 +107,9 @@ function(select_nvcc_arch_flags out_variable)
elseif(${CUDA_ARCH_NAME} STREQUAL "Maxwell") elseif(${CUDA_ARCH_NAME} STREQUAL "Maxwell")
set(cuda_arch_bin "50") set(cuda_arch_bin "50")
elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal") elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal")
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} LESS 10.0)
add_definitions("-DSUPPORTS_CUDA_FP16")
endif()
set(cuda_arch_bin "60 61") set(cuda_arch_bin "60 61")
elseif(${CUDA_ARCH_NAME} STREQUAL "Volta") elseif(${CUDA_ARCH_NAME} STREQUAL "Volta")
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} LESS 10.0) if (NOT ${CMAKE_CUDA_COMPILER_VERSION} LESS 10.0)
...@@ -188,6 +192,10 @@ elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) # CUDA 10.x ...@@ -188,6 +192,10 @@ elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) # CUDA 10.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs10}) set(paddle_known_gpu_archs ${paddle_known_gpu_archs10})
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__")
elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0) # CUDA 11.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs11})
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__")
endif() endif()
add_definitions("-DPADDLE_CUDA_BINVER=\"${CUDA_VERSION_MAJOR}${CUDA_VERSION_MINOR}\"") add_definitions("-DPADDLE_CUDA_BINVER=\"${CUDA_VERSION_MAJOR}${CUDA_VERSION_MINOR}\"")
......
...@@ -22,23 +22,8 @@ SET(CRYPTOPP_TAG CRYPTOPP_8_2_0) ...@@ -22,23 +22,8 @@ SET(CRYPTOPP_TAG CRYPTOPP_8_2_0)
IF(WIN32) IF(WIN32)
SET(CRYPTOPP_LIBRARIES "${CRYPTOPP_INSTALL_DIR}/lib/cryptopp-static.lib" CACHE FILEPATH "cryptopp library." FORCE) SET(CRYPTOPP_LIBRARIES "${CRYPTOPP_INSTALL_DIR}/lib/cryptopp-static.lib" CACHE FILEPATH "cryptopp library." FORCE)
SET(CRYPTOPP_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MTd")
set(CompilerFlags
CMAKE_CXX_FLAGS
CMAKE_CXX_FLAGS_DEBUG
CMAKE_CXX_FLAGS_RELEASE
CMAKE_C_FLAGS
CMAKE_C_FLAGS_DEBUG
CMAKE_C_FLAGS_RELEASE
)
foreach(CompilerFlag ${CompilerFlags})
string(REPLACE "/MD" "/MT" ${CompilerFlag} "${${CompilerFlag}}")
endforeach()
ELSE(WIN32) ELSE(WIN32)
SET(CRYPTOPP_LIBRARIES "${CRYPTOPP_INSTALL_DIR}/lib/libcryptopp.a" CACHE FILEPATH "cryptopp library." FORCE) SET(CRYPTOPP_LIBRARIES "${CRYPTOPP_INSTALL_DIR}/lib/libcryptopp.a" CACHE FILEPATH "cryptopp library." FORCE)
SET(CRYPTOPP_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
ENDIF(WIN32) ENDIF(WIN32)
set(CRYPTOPP_CMAKE_ARGS ${COMMON_CMAKE_ARGS} set(CRYPTOPP_CMAKE_ARGS ${COMMON_CMAKE_ARGS}
...@@ -48,7 +33,7 @@ set(CRYPTOPP_CMAKE_ARGS ${COMMON_CMAKE_ARGS} ...@@ -48,7 +33,7 @@ set(CRYPTOPP_CMAKE_ARGS ${COMMON_CMAKE_ARGS}
-DCMAKE_INSTALL_LIBDIR=${CRYPTOPP_INSTALL_DIR}/lib -DCMAKE_INSTALL_LIBDIR=${CRYPTOPP_INSTALL_DIR}/lib
-DCMAKE_INSTALL_PREFIX=${CRYPTOPP_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${CRYPTOPP_INSTALL_DIR}
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_CXX_FLAGS=${CRYPTOPP_CMAKE_CXX_FLAGS} -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
...@@ -19,7 +19,7 @@ SET(DGC_SOURCES_DIR "${THIRD_PARTY_PATH}/dgc/src/extern_dgc") ...@@ -19,7 +19,7 @@ SET(DGC_SOURCES_DIR "${THIRD_PARTY_PATH}/dgc/src/extern_dgc")
SET(DGC_INSTALL_DIR "${THIRD_PARTY_PATH}/install/dgc") SET(DGC_INSTALL_DIR "${THIRD_PARTY_PATH}/install/dgc")
SET(DGC_INCLUDE_DIR "${DGC_INSTALL_DIR}/include" CACHE PATH "dgc include directory." FORCE) SET(DGC_INCLUDE_DIR "${DGC_INSTALL_DIR}/include" CACHE PATH "dgc include directory." FORCE)
SET(DGC_LIBRARIES "${DGC_INSTALL_DIR}/lib/libdgc.a" CACHE FILEPATH "dgc library." FORCE) SET(DGC_LIBRARIES "${DGC_INSTALL_DIR}/lib/libdgc.a" CACHE FILEPATH "dgc library." FORCE)
SET(DGC_URL "http://fleet.bj.bcebos.com/collective_ef2216a.tgz") SET(DGC_URL "https://fleet.bj.bcebos.com/dgc/collective_f66ef73.tgz")
INCLUDE_DIRECTORIES(${DGC_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${DGC_INCLUDE_DIR})
cache_third_party(extern_dgc cache_third_party(extern_dgc
...@@ -30,7 +30,7 @@ ExternalProject_Add( ...@@ -30,7 +30,7 @@ ExternalProject_Add(
extern_dgc extern_dgc
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
"${DGC_DOWNLOAD_CMD}" "${DGC_DOWNLOAD_CMD}"
URL_MD5 "2f67549fd5f1262383d83289abc4f88f" URL_MD5 "94e6fa1bc97169d0e1aad44570fe3251"
PREFIX "${DGC_PREFIX_DIR}" PREFIX "${DGC_PREFIX_DIR}"
SOURCE_DIR "${DGC_SOURCES_DIR}" SOURCE_DIR "${DGC_SOURCES_DIR}"
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
......
...@@ -34,7 +34,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) ...@@ -34,7 +34,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite) set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite)
if(NOT LITE_GIT_TAG) if(NOT LITE_GIT_TAG)
set(LITE_GIT_TAG dfdfa6440c83bf0b415f9f5a9ff84842ce0bb0fa) set(LITE_GIT_TAG 6d2b2a4028a58715b01887b04eb9bff8432eb184)
endif() endif()
if(NOT CUDA_ARCH_NAME) if(NOT CUDA_ARCH_NAME)
......
...@@ -19,8 +19,8 @@ SET(MKLDNN_PREFIX_DIR ${THIRD_PARTY_PATH}/mkldnn) ...@@ -19,8 +19,8 @@ SET(MKLDNN_PREFIX_DIR ${THIRD_PARTY_PATH}/mkldnn)
SET(MKLDNN_SOURCE_DIR ${THIRD_PARTY_PATH}/mkldnn/src/extern_mkldnn) SET(MKLDNN_SOURCE_DIR ${THIRD_PARTY_PATH}/mkldnn/src/extern_mkldnn)
SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn) SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE) SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
SET(MKLDNN_REPOSITORY https://github.com/intel/mkl-dnn.git) SET(MKLDNN_REPOSITORY https://github.com/oneapi-src/oneDNN.git)
SET(MKLDNN_TAG 1ea812f4f5aa1bd989372a23ab50d0f0f81ee677) SET(MKLDNN_TAG 64a48f9565aa72f6359917b3406328075a409939)
# Introduce variables: # Introduce variables:
# * CMAKE_INSTALL_LIBDIR # * CMAKE_INSTALL_LIBDIR
......
...@@ -18,7 +18,7 @@ SET(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc) ...@@ -18,7 +18,7 @@ SET(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc)
SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc) SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc)
SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc) SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc)
set(WARPCTC_REPOSITORY https://github.com/baidu-research/warp-ctc.git) set(WARPCTC_REPOSITORY https://github.com/baidu-research/warp-ctc.git)
set(WARPCTC_TAG bc29dcfff07ced1c7a19a4ecee48e5ad583cef8e) set(WARPCTC_TAG fc7f226b93758216a03b1be9d24593a12819b984)
SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include" SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include"
CACHE PATH "Warp-ctc Directory" FORCE) CACHE PATH "Warp-ctc Directory" FORCE)
......
...@@ -28,7 +28,15 @@ function(CheckCompilerCXX11Flag) ...@@ -28,7 +28,15 @@ function(CheckCompilerCXX11Flag)
endfunction() endfunction()
CheckCompilerCXX11Flag() CheckCompilerCXX11Flag()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") if (WITH_GPU)
if (${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
endif()
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
endif()
# safe_set_flag # safe_set_flag
# #
# Set a compile flag only if compiler is support # Set a compile flag only if compiler is support
...@@ -82,20 +90,6 @@ macro(safe_set_nvflag flag_name) ...@@ -82,20 +90,6 @@ macro(safe_set_nvflag flag_name)
endif() endif()
endmacro() endmacro()
macro(safe_set_static_flag) # set c_flags and cxx_flags to static or shared
if (BUILD_SHARED_LIBS)
return() # if build shared libs, the flags keep same with '/MD'
endif(BUILD_SHARED_LIBS)
foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO)
if(${flag_var} MATCHES "/MD")
string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
endif(${flag_var} MATCHES "/MD")
endforeach(flag_var)
endmacro()
CHECK_CXX_SYMBOL_EXISTS(UINT64_MAX "stdint.h" UINT64_MAX_EXISTS) CHECK_CXX_SYMBOL_EXISTS(UINT64_MAX "stdint.h" UINT64_MAX_EXISTS)
if(NOT UINT64_MAX_EXISTS) if(NOT UINT64_MAX_EXISTS)
...@@ -221,20 +215,3 @@ endforeach() ...@@ -221,20 +215,3 @@ endforeach()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${SAFE_GPU_COMMON_FLAGS}") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${SAFE_GPU_COMMON_FLAGS}")
if(WIN32)
# windows build turn off warnings.
if(MSVC_STATIC_CRT)
safe_set_static_flag()
endif()
foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO)
string(REGEX REPLACE "/W[1-4]" " /W0 " ${flag_var} "${${flag_var}}")
endforeach(flag_var)
foreach(flag_var CMAKE_CXX_FLAGS CMAKE_C_FLAGS)
set(${flag_var} "${${flag_var}} /w")
endforeach(flag_var)
endif()
...@@ -386,7 +386,7 @@ function(cc_test_run TARGET_NAME) ...@@ -386,7 +386,7 @@ function(cc_test_run TARGET_NAME)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
# No unit test should exceed 2 minutes. # No unit test should exceed 2 minutes.
if (APPLE OR WIN32) if (APPLE OR WIN32)
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600) set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 150)
else() else()
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 120) set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 120)
endif() endif()
...@@ -748,7 +748,7 @@ function(py_test TARGET_NAME) ...@@ -748,7 +748,7 @@ function(py_test TARGET_NAME)
endif() endif()
if (APPLE OR WIN32) if (APPLE OR WIN32)
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600) set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 150)
else() else()
# No unit test should exceed 2 minutes in Linux. # No unit test should exceed 2 minutes in Linux.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 120) set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 120)
......
...@@ -13,18 +13,18 @@ ...@@ -13,18 +13,18 @@
# limitations under the License. # limitations under the License.
# make package for paddle fluid shared and static library # make package for paddle fluid shared and static library
set(FLUID_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_install_dir" CACHE STRING set(PADDLE_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_install_dir" CACHE STRING
"A path setting fluid shared and static libraries") "A path setting paddle shared and static libraries")
set(FLUID_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_install_dir" CACHE STRING set(PADDLE_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_inference_install_dir" CACHE STRING
"A path setting fluid inference shared and static libraries") "A path setting paddle inference shared and static libraries")
# TODO(zhaolong) # TODO(zhaolong)
# At present, the size of static lib in Windows exceeds the system limit, # At present, the size of static lib in Windows exceeds the system limit,
# so the generation of static lib is temporarily turned off. # so the generation of static lib is temporarily turned off.
if(WIN32) if(WIN32)
#todo: remove the option #todo: remove the option
option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." OFF) option(WITH_STATIC_LIB "Compile demo with static/shared library, default use dynamic." OFF)
if(NOT PYTHON_EXECUTABLE) if(NOT PYTHON_EXECUTABLE)
FIND_PACKAGE(PythonInterp REQUIRED) FIND_PACKAGE(PythonInterp REQUIRED)
endif() endif()
...@@ -142,14 +142,14 @@ set(inference_lib_deps third_party paddle_fluid paddle_fluid_c paddle_fluid_shar ...@@ -142,14 +142,14 @@ set(inference_lib_deps third_party paddle_fluid paddle_fluid_c paddle_fluid_shar
add_custom_target(inference_lib_dist DEPENDS ${inference_lib_deps}) add_custom_target(inference_lib_dist DEPENDS ${inference_lib_deps})
set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/threadpool") set(dst_dir "${PADDLE_INFERENCE_INSTALL_DIR}/third_party/threadpool")
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${THREADPOOL_INCLUDE_DIR}/ThreadPool.h SRCS ${THREADPOOL_INCLUDE_DIR}/ThreadPool.h
DSTS ${dst_dir}) DSTS ${dst_dir})
# Only GPU need cudaErrorMessage.pb # Only GPU need cudaErrorMessage.pb
IF(WITH_GPU) IF(WITH_GPU)
set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/cudaerror/data") set(dst_dir "${PADDLE_INFERENCE_INSTALL_DIR}/third_party/cudaerror/data")
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${cudaerror_INCLUDE_DIR} SRCS ${cudaerror_INCLUDE_DIR}
DSTS ${dst_dir}) DSTS ${dst_dir})
...@@ -158,65 +158,62 @@ ENDIF() ...@@ -158,65 +158,62 @@ ENDIF()
# CMakeCache Info # CMakeCache Info
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt SRCS ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt
DSTS ${FLUID_INFERENCE_INSTALL_DIR}) DSTS ${PADDLE_INFERENCE_INSTALL_DIR})
copy_part_of_thrid_party(inference_lib_dist ${FLUID_INFERENCE_INSTALL_DIR}) copy_part_of_thrid_party(inference_lib_dist ${PADDLE_INFERENCE_INSTALL_DIR})
set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid") set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")
if(WIN32) if(WIN32)
if(WITH_STATIC_LIB) if(WITH_STATIC_LIB)
set(paddle_fluid_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/${CMAKE_BUILD_TYPE}/libpaddle_fluid.lib) set(paddle_fluid_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/${CMAKE_BUILD_TYPE}/libpaddle_fluid.lib
${PADDLE_BINARY_DIR}/paddle/fluid/inference/${CMAKE_BUILD_TYPE}/paddle_fluid.*)
else() else()
set(paddle_fluid_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/${CMAKE_BUILD_TYPE}/paddle_fluid.dll set(paddle_fluid_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/${CMAKE_BUILD_TYPE}/paddle_fluid.dll
${PADDLE_BINARY_DIR}/paddle/fluid/inference/${CMAKE_BUILD_TYPE}/paddle_fluid.lib) ${PADDLE_BINARY_DIR}/paddle/fluid/inference/${CMAKE_BUILD_TYPE}/paddle_fluid.lib)
endif() endif()
else(WIN32)
set(paddle_fluid_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.*)
endif(WIN32)
if(WIN32 AND NOT WITH_STATIC_LIB)
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${src_dir}/inference/api/paddle_*.h ${paddle_fluid_lib} SRCS ${src_dir}/inference/api/paddle_*.h ${paddle_fluid_lib}
DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include ${FLUID_INFERENCE_INSTALL_DIR}/paddle/lib DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib
${FLUID_INFERENCE_INSTALL_DIR}/paddle/lib) ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)
else() else(WIN32)
set(paddle_fluid_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.*)
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${src_dir}/inference/api/paddle_*.h ${paddle_fluid_lib} SRCS ${src_dir}/inference/api/paddle_*.h ${paddle_fluid_lib}
DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include ${FLUID_INFERENCE_INSTALL_DIR}/paddle/lib) DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)
endif() endif(WIN32)
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${CMAKE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h SRCS ${CMAKE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h
DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include/internal) DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/internal)
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/framework/io/crypto/cipher.h SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/framework/io/crypto/cipher.h
DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include/crypto/) DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/crypto/)
include_directories(${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io) include_directories(${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io)
# CAPI inference library for only inference # CAPI inference library for only inference
set(FLUID_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_c_install_dir" CACHE STRING set(PADDLE_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_inference_c_install_dir" CACHE STRING
"A path setting CAPI fluid inference shared") "A path setting CAPI paddle inference shared")
copy_part_of_thrid_party(inference_lib_dist ${FLUID_INFERENCE_C_INSTALL_DIR}) copy_part_of_thrid_party(inference_lib_dist ${PADDLE_INFERENCE_C_INSTALL_DIR})
set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid") set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")
set(paddle_fluid_c_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/capi/libpaddle_fluid_c.*) set(paddle_fluid_c_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/capi/libpaddle_fluid_c.*)
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${src_dir}/inference/capi/paddle_c_api.h ${paddle_fluid_c_lib} SRCS ${src_dir}/inference/capi/paddle_c_api.h ${paddle_fluid_c_lib}
DSTS ${FLUID_INFERENCE_C_INSTALL_DIR}/paddle/include ${FLUID_INFERENCE_C_INSTALL_DIR}/paddle/lib) DSTS ${PADDLE_INFERENCE_C_INSTALL_DIR}/paddle/include ${PADDLE_INFERENCE_C_INSTALL_DIR}/paddle/lib)
# fluid library for both train and inference # fluid library for both train and inference
set(fluid_lib_deps inference_lib_dist) set(fluid_lib_deps inference_lib_dist)
add_custom_target(fluid_lib_dist ALL DEPENDS ${fluid_lib_deps}) add_custom_target(fluid_lib_dist ALL DEPENDS ${fluid_lib_deps})
set(dst_dir "${FLUID_INSTALL_DIR}/paddle/fluid") set(dst_dir "${PADDLE_INSTALL_DIR}/paddle/fluid")
set(module "inference") set(module "inference")
if(WIN32 AND NOT WITH_STATIC_LIB) if(WIN32)
copy(fluid_lib_dist copy(fluid_lib_dist
SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/api/paddle_*.h ${paddle_fluid_lib} SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/api/paddle_*.h ${paddle_fluid_lib}
DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
) )
else() else()
copy(fluid_lib_dist copy(fluid_lib_dist
SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/api/paddle_*.h ${paddle_fluid_lib} SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/api/paddle_*.h ${paddle_fluid_lib}
DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
...@@ -273,22 +270,22 @@ copy(fluid_lib_dist ...@@ -273,22 +270,22 @@ copy(fluid_lib_dist
DSTS ${dst_dir}/${module} DSTS ${dst_dir}/${module}
) )
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/eigen3") set(dst_dir "${PADDLE_INSTALL_DIR}/third_party/eigen3")
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${EIGEN_INCLUDE_DIR}/Eigen/Core ${EIGEN_INCLUDE_DIR}/Eigen/src ${EIGEN_INCLUDE_DIR}/unsupported/Eigen SRCS ${EIGEN_INCLUDE_DIR}/Eigen/Core ${EIGEN_INCLUDE_DIR}/Eigen/src ${EIGEN_INCLUDE_DIR}/unsupported/Eigen
DSTS ${dst_dir}/Eigen ${dst_dir}/Eigen ${dst_dir}/unsupported) DSTS ${dst_dir}/Eigen ${dst_dir}/Eigen ${dst_dir}/unsupported)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/boost") set(dst_dir "${PADDLE_INSTALL_DIR}/third_party/boost")
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${BOOST_INCLUDE_DIR}/boost SRCS ${BOOST_INCLUDE_DIR}/boost
DSTS ${dst_dir}) DSTS ${dst_dir})
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/dlpack") set(dst_dir "${PADDLE_INSTALL_DIR}/third_party/dlpack")
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${DLPACK_INCLUDE_DIR}/dlpack SRCS ${DLPACK_INCLUDE_DIR}/dlpack
DSTS ${dst_dir}) DSTS ${dst_dir})
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/zlib") set(dst_dir "${PADDLE_INSTALL_DIR}/third_party/install/zlib")
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${ZLIB_INCLUDE_DIR} ${ZLIB_LIBRARIES} SRCS ${ZLIB_INCLUDE_DIR} ${ZLIB_LIBRARIES}
DSTS ${dst_dir} ${dst_dir}/lib) DSTS ${dst_dir} ${dst_dir}/lib)
...@@ -296,8 +293,8 @@ copy(inference_lib_dist ...@@ -296,8 +293,8 @@ copy(inference_lib_dist
# CMakeCache Info # CMakeCache Info
copy(fluid_lib_dist copy(fluid_lib_dist
SRCS ${FLUID_INFERENCE_INSTALL_DIR}/third_party ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt SRCS ${PADDLE_INFERENCE_INSTALL_DIR}/third_party ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt
DSTS ${FLUID_INSTALL_DIR} ${FLUID_INSTALL_DIR} DSTS ${PADDLE_INSTALL_DIR} ${PADDLE_INSTALL_DIR}
) )
# paddle fluid version # paddle fluid version
...@@ -323,6 +320,6 @@ function(version version_file) ...@@ -323,6 +320,6 @@ function(version version_file)
endif() endif()
endfunction() endfunction()
version(${FLUID_INSTALL_DIR}/version.txt) version(${PADDLE_INSTALL_DIR}/version.txt)
version(${FLUID_INFERENCE_INSTALL_DIR}/version.txt) version(${PADDLE_INFERENCE_INSTALL_DIR}/version.txt)
version(${FLUID_INFERENCE_C_INSTALL_DIR}/version.txt) version(${PADDLE_INFERENCE_C_INSTALL_DIR}/version.txt)
...@@ -127,7 +127,8 @@ function(op_library TARGET) ...@@ -127,7 +127,8 @@ function(op_library TARGET)
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op") "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op"
"fused_bn_add_activation_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}") if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1) set(pybind_flag 1)
endif() endif()
...@@ -138,12 +139,17 @@ function(op_library TARGET) ...@@ -138,12 +139,17 @@ function(op_library TARGET)
# And for detail pybind information, please see generated paddle/pybind/pybind.h. # And for detail pybind information, please see generated paddle/pybind/pybind.h.
file(READ ${TARGET}.cc TARGET_CONTENT) file(READ ${TARGET}.cc TARGET_CONTENT)
string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}") string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}")
string(REGEX MATCH "REGISTER_OPERATOR\\([a-z0-9_]*," one_register "${multi_register}") # [ \t\r\n]* is used for blank characters
string(REGEX MATCH "REGISTER_OPERATOR\\([ \t\r\n]*[a-z0-9_]*," one_register "${multi_register}")
if (one_register STREQUAL "") if (one_register STREQUAL "")
string(REPLACE "_op" "" TARGET "${TARGET}") string(REPLACE "_op" "" TARGET "${TARGET}")
else () else ()
string(REPLACE "REGISTER_OPERATOR(" "" TARGET "${one_register}") string(REPLACE "REGISTER_OPERATOR(" "" TARGET "${one_register}")
string(REPLACE "," "" TARGET "${TARGET}") string(REPLACE "," "" TARGET "${TARGET}")
# [ \t\r\n]+ is used for blank characters.
# Here we use '+' instead of '*' since it is a REPLACE operation.
string(REGEX REPLACE "[ \t\r\n]+" "" TARGET "${TARGET}")
endif() endif()
# pybind USE_NO_KERNEL_OP # pybind USE_NO_KERNEL_OP
......
...@@ -243,9 +243,10 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC)) ...@@ -243,9 +243,10 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
ENDIF() ENDIF()
if(WITH_GPU) if(WITH_GPU)
if (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0)
include(external/cub) # download cub include(external/cub) # download cub
list(APPEND third_party_deps extern_cub) list(APPEND third_party_deps extern_cub)
endif()
set(CUDAERROR_URL "http://paddlepaddledeps.bj.bcebos.com/cudaErrorMessage.tar.gz" CACHE STRING "" FORCE) set(CUDAERROR_URL "http://paddlepaddledeps.bj.bcebos.com/cudaErrorMessage.tar.gz" CACHE STRING "" FORCE)
file_download_and_uncompress(${CUDAERROR_URL} "cudaerror") # download file cudaErrorMessage file_download_and_uncompress(${CUDAERROR_URL} "cudaerror") # download file cudaErrorMessage
endif(WITH_GPU) endif(WITH_GPU)
......
# Paddle 预测golang API # Paddle 预测golang API
## 安装 ## 安装
首先cmake编译时打开`-DON_INFER=ON`,在编译目录下得到``fluid_inference_c_install_dir``,将该目录移动到当前目录中并重命名为`paddle_c` 首先cmake编译时打开`-DON_INFER=ON`,在编译目录下得到``paddle_inference_c_install_dir``,将该目录移动到当前目录中并重命名为`paddle_c`
## 在Go中使用Paddle预测 ## 在Go中使用Paddle预测
首先创建预测配置 首先创建预测配置
......
...@@ -154,10 +154,17 @@ func (config *AnalysisConfig) EnableMkldnnQuantizer() { ...@@ -154,10 +154,17 @@ func (config *AnalysisConfig) EnableMkldnnQuantizer() {
C.PD_EnableMkldnnQuantizer(config.c) C.PD_EnableMkldnnQuantizer(config.c)
} }
func (config *AnalysisConfig) EnableMkldnnBfloat16() {
C.PD_EnableMkldnnBfloat16(config.c)
}
func (config *AnalysisConfig) MkldnnQuantizerEnabled() bool { func (config *AnalysisConfig) MkldnnQuantizerEnabled() bool {
return ConvertCBooleanToGo(C.PD_MkldnnQuantizerEnabled(config.c)) return ConvertCBooleanToGo(C.PD_MkldnnQuantizerEnabled(config.c))
} }
func (config *AnalysisConfig) MkldnnBfloat16Enabled() bool {
return ConvertCBooleanToGo(C.PD_MkldnnBfloat16Enabled(config.c))
}
// SetModelBuffer // SetModelBuffer
// ModelFromMemory // ModelFromMemory
......
...@@ -272,7 +272,7 @@ cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatib ...@@ -272,7 +272,7 @@ cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatib
cc_library(save_load_util SRCS save_load_util DEPS tensor scope layer) cc_library(save_load_util SRCS save_load_util DEPS tensor scope layer)
cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer) cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
cc_library(generator SRCS generator.cc) cc_library(generator SRCS generator.cc DEPS enforce place)
# Get the current working branch # Get the current working branch
execute_process( execute_process(
......
...@@ -49,7 +49,8 @@ std::vector<std::string> PD_GetGradOpDescStrs( ...@@ -49,7 +49,8 @@ std::vector<std::string> PD_GetGradOpDescStrs(
for (size_t i = 0; i < op_num; ++i) { for (size_t i = 0; i < op_num; ++i) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
grad_op_descs[i]->Proto()->SerializePartialToString(&ret[i]), true, grad_op_descs[i]->Proto()->SerializePartialToString(&ret[i]), true,
"Cannot serialize message."); paddle::platform::errors::Unavailable(
"Cannot serialize operator desc message."));
} }
} }
return ret; return ret;
......
...@@ -527,6 +527,8 @@ bool MultiSlotDataFeed::CheckFile(const char* filename) { ...@@ -527,6 +527,8 @@ bool MultiSlotDataFeed::CheckFile(const char* filename) {
VLOG(0) << "error: the number of ids is a negative number: " << num; VLOG(0) << "error: the number of ids is a negative number: " << num;
VLOG(0) << "please check line<" << instance_cout << "> in file<" VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">"; << filename << ">";
VLOG(0) << "Error occured when parsing " << i
<< " th slot with total slots number: " << all_slots_.size();
return false; return false;
} else if (num == 0) { } else if (num == 0) {
VLOG(0) VLOG(0)
...@@ -536,42 +538,66 @@ bool MultiSlotDataFeed::CheckFile(const char* filename) { ...@@ -536,42 +538,66 @@ bool MultiSlotDataFeed::CheckFile(const char* filename) {
"characters."; "characters.";
VLOG(0) << "please check line<" << instance_cout << "> in file<" VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">"; << filename << ">";
VLOG(0) << "Error occured when parsing " << i
<< " th slot with total slots number: " << all_slots_.size();
return false; return false;
} else if (errno == ERANGE || num > INT_MAX) { } else if (errno == ERANGE || num > INT_MAX) {
VLOG(0) << "error: the number of ids greater than INT_MAX"; VLOG(0) << "error: the number of ids greater than INT_MAX";
VLOG(0) << "please check line<" << instance_cout << "> in file<" VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">"; << filename << ">";
VLOG(0) << "Error occured when parsing " << i
<< " th slot with total slots number: " << all_slots_.size();
return false; return false;
} }
if (all_slots_type_[i] == "float") { if (all_slots_type_[i] == "float") {
for (int i = 0; i < num; ++i) { for (int j = 0; j < num; ++j) {
strtof(endptr, &endptr); strtof(endptr, &endptr);
if (errno == ERANGE) { if (errno == ERANGE) {
VLOG(0) << "error: the value is out of the range of " VLOG(0) << "error: the value is out of the range of "
"representable values for float"; "representable values for float";
VLOG(0) << "please check line<" << instance_cout << "> in file<" VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">"; << filename << ">";
VLOG(0) << "Error occured when parsing " << i
<< " th slot with total slots number: "
<< all_slots_.size();
VLOG(0) << "and in this slot: " << j
<< " th id with total id number: " << num;
return false; return false;
} }
if (i + 1 != num && endptr - str == len) { if (j + 1 != num && endptr - str == len) {
VLOG(0) << "error: there is a wrong with the number of ids."; VLOG(0) << "error: there is a wrong with the number of ids.";
VLOG(0) << "Error occured when parsing " << i
<< " th slot with total slots number: "
<< all_slots_.size();
VLOG(0) << "and in this slot: " << j
<< " th id with total id number: " << num;
VLOG(0) << "please check line<" << instance_cout << "> in file<" VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">"; << filename << ">";
return false; return false;
} }
} }
} else if (all_slots_type_[i] == "uint64") { } else if (all_slots_type_[i] == "uint64") {
for (int i = 0; i < num; ++i) { for (int j = 0; j < num; ++j) {
strtoull(endptr, &endptr, 10); strtoull(endptr, &endptr, 10);
if (errno == ERANGE) { if (errno == ERANGE) {
VLOG(0) << "error: the value is out of the range of " VLOG(0) << "error: the value is out of the range of "
"representable values for uint64_t"; "representable values for uint64_t";
VLOG(0) << "Error occured when parsing " << i
<< " th slot with total slots number: "
<< all_slots_.size();
VLOG(0) << "and in this slot: " << j
<< " th id with total id number: " << num;
VLOG(0) << "please check line<" << instance_cout << "> in file<" VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">"; << filename << ">";
return false; return false;
} }
if (i + 1 != num && endptr - str == len) { if (j + 1 != num && endptr - str == len) {
VLOG(0) << "error: there is a wrong with the number of ids."; VLOG(0) << "error: there is a wrong with the number of ids.";
VLOG(0) << "Error occured when parsing " << i
<< " th slot with total slots number: "
<< all_slots_.size();
VLOG(0) << "and in this slot: " << j
<< " th id with total id number: " << num;
VLOG(0) << "please check line<" << instance_cout << "> in file<" VLOG(0) << "please check line<" << instance_cout << "> in file<"
<< filename << ">"; << filename << ">";
return false; return false;
...@@ -632,8 +658,13 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe( ...@@ -632,8 +658,13 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
"The number of ids can not be zero, you need padding " "The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with " "it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable " "the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s", "characters.\nplease check this error line: %s, \n Specifically, "
str)); "something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots."
"Maybe something wrong around this slot",
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));
if (idx != -1) { if (idx != -1) {
(*instance)[idx].Init(all_slots_type_[i]); (*instance)[idx].Init(all_slots_type_[i]);
if ((*instance)[idx].GetType()[0] == 'f') { // float if ((*instance)[idx].GetType()[0] == 'f') { // float
...@@ -683,8 +714,13 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) { ...@@ -683,8 +714,13 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
"The number of ids can not be zero, you need padding " "The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with " "it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable " "the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s.", "characters.\nplease check this error line: %s, \n Specifically, "
str)); "something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots."
"Maybe something wrong around this slot",
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));
if (idx != -1) { if (idx != -1) {
(*instance)[idx].Init(all_slots_type_[i]); (*instance)[idx].Init(all_slots_type_[i]);
...@@ -916,8 +952,13 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { ...@@ -916,8 +952,13 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
"The number of ids can not be zero, you need padding " "The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with " "it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable " "the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s.", "characters.\nplease check this error line: %s, \n Specifically, "
str)); "something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots."
"Maybe something wrong around this slot",
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));
if (idx != -1) { if (idx != -1) {
if (all_slots_type_[i][0] == 'f') { // float if (all_slots_type_[i][0] == 'f') { // float
for (int j = 0; j < num; ++j) { for (int j = 0; j < num; ++j) {
...@@ -982,8 +1023,13 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) { ...@@ -982,8 +1023,13 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) {
"The number of ids can not be zero, you need padding " "The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with " "it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable " "the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s.", "characters.\nplease check this error line: %s, \n Specifically, "
str)); "something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots."
"Maybe something wrong around this slot",
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));
if (idx != -1) { if (idx != -1) {
if (all_slots_type_[i][0] == 'f') { // float if (all_slots_type_[i][0] == 'f') { // float
......
...@@ -116,6 +116,8 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) { ...@@ -116,6 +116,8 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
return platform::to_void_cast(tensor.data<unsigned char>()); return platform::to_void_cast(tensor.data<unsigned char>());
case mkldnn::memory::data_type::s32: case mkldnn::memory::data_type::s32:
return platform::to_void_cast(tensor.data<int32_t>()); return platform::to_void_cast(tensor.data<int32_t>());
case mkldnn::memory::data_type::bf16:
return platform::to_void_cast(tensor.data<paddle::platform::bfloat16>());
default: default:
PADDLE_THROW( PADDLE_THROW(
platform::errors::InvalidArgument("Wrong mkldnn type provided.")); platform::errors::InvalidArgument("Wrong mkldnn type provided."));
......
...@@ -61,7 +61,8 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) { ...@@ -61,7 +61,8 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
{DataTypeTrait<float>::DataType(), MKLDNNDataType::f32}, {DataTypeTrait<float>::DataType(), MKLDNNDataType::f32},
{DataTypeTrait<int8_t>::DataType(), MKLDNNDataType::s8}, {DataTypeTrait<int8_t>::DataType(), MKLDNNDataType::s8},
{DataTypeTrait<uint8_t>::DataType(), MKLDNNDataType::u8}, {DataTypeTrait<uint8_t>::DataType(), MKLDNNDataType::u8},
{DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32}}; {DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32},
{DataTypeTrait<platform::bfloat16>::DataType(), MKLDNNDataType::bf16}};
auto iter = dict.find(static_cast<int>(type)); auto iter = dict.find(static_cast<int>(type));
if (iter != dict.end()) return iter->second; if (iter != dict.end()) return iter->second;
return MKLDNNDataType::undef; return MKLDNNDataType::undef;
...@@ -74,6 +75,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, ...@@ -74,6 +75,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type, const OpKernelType& expected_kernel_type,
const Tensor& in, Tensor* out); const Tensor& in, Tensor* out);
void* GetDataFromTensor(const Tensor& tensor, MKLDNNDataType type);
#endif #endif
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to); std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
......
...@@ -43,3 +43,17 @@ TEST(DataTransform, DataLayoutFunction) { ...@@ -43,3 +43,17 @@ TEST(DataTransform, DataLayoutFunction) {
EXPECT_TRUE(in.layout() == paddle::framework::DataLayout::kNHWC); EXPECT_TRUE(in.layout() == paddle::framework::DataLayout::kNHWC);
EXPECT_TRUE(in.dims() == paddle::framework::make_ddim({2, 3, 1, 2})); EXPECT_TRUE(in.dims() == paddle::framework::make_ddim({2, 3, 1, 2}));
} }
#ifdef PADDLE_WITH_MKLDNN
TEST(DataTransform, GetDataFromTensorDNNL) {
auto place = paddle::platform::CPUPlace();
paddle::framework::Tensor in = paddle::framework::Tensor();
in.mutable_data<paddle::platform::bfloat16>(
paddle::framework::make_ddim({2, 3, 1, 2}), place);
void* in_data =
paddle::framework::GetDataFromTensor(in, dnnl::memory::data_type::bf16);
EXPECT_EQ(in_data, paddle::platform::to_void_cast(
in.data<paddle::platform::bfloat16>()));
}
#endif
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <unordered_map> #include <unordered_map>
using float16 = paddle::platform::float16; using float16 = paddle::platform::float16;
using bfloat16 = paddle::platform::bfloat16;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -17,6 +17,8 @@ limitations under the License. */ ...@@ -17,6 +17,8 @@ limitations under the License. */
#include <typeindex> #include <typeindex>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
...@@ -39,6 +41,7 @@ struct DataTypeTrait<void> { ...@@ -39,6 +41,7 @@ struct DataTypeTrait<void> {
#define _ForEachDataType_(callback) \ #define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \ _ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \ _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16); \
_ForEachDataTypeHelper_(callback, double, FP64); \ _ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \ _ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \ _ForEachDataTypeHelper_(callback, int64_t, INT64); \
......
...@@ -38,3 +38,25 @@ TEST(DataType, float16) { ...@@ -38,3 +38,25 @@ TEST(DataType, float16) {
std::string type = "::paddle::platform::float16"; std::string type = "::paddle::platform::float16";
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str()); EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
} }
TEST(DataType, bfloat16) {
using paddle::framework::Tensor;
using paddle::platform::CPUPlace;
using paddle::platform::bfloat16;
namespace f = paddle::framework;
f::proto::VarType::Type dtype = f::proto::VarType::BF16;
Tensor tensor;
CPUPlace cpu;
tensor.mutable_data(cpu, dtype);
// test bf16 tensor
EXPECT_EQ(tensor.type(), f::ToDataType(typeid(bfloat16)));
// test bf16 size
EXPECT_EQ(f::SizeOfType(dtype), 2u);
// test debug info
std::string type = "::paddle::platform::bfloat16";
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
}
...@@ -77,6 +77,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var, ...@@ -77,6 +77,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
framework::VisitDataType(dst_type, framework::VisitDataType(dst_type,
CastDataType<platform::float16>(in, out, ctx)); CastDataType<platform::float16>(in, out, ctx));
break; break;
case proto::VarType::BF16:
framework::VisitDataType(dst_type,
CastDataType<platform::bfloat16>(in, out, ctx));
break;
case proto::VarType::FP32: case proto::VarType::FP32:
framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx)); framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
break; break;
......
...@@ -24,6 +24,11 @@ TEST(DataTypeTransform, CPUTransform) { ...@@ -24,6 +24,11 @@ TEST(DataTypeTransform, CPUTransform) {
paddle::framework::DataLayout::kAnyLayout, paddle::framework::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain); paddle::framework::LibraryType::kPlain);
auto kernel_bf16 = paddle::framework::OpKernelType(
paddle::framework::proto::VarType::BF16, place,
paddle::framework::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_fp32 = paddle::framework::OpKernelType( auto kernel_fp32 = paddle::framework::OpKernelType(
paddle::framework::proto::VarType::FP32, place, paddle::framework::proto::VarType::FP32, place,
paddle::framework::DataLayout::kAnyLayout, paddle::framework::DataLayout::kAnyLayout,
...@@ -189,4 +194,120 @@ TEST(DataTypeTransform, CPUTransform) { ...@@ -189,4 +194,120 @@ TEST(DataTypeTransform, CPUTransform) {
static_cast<paddle::platform::float16>(in_data_bool[i]).x); static_cast<paddle::platform::float16>(in_data_bool[i]).x);
} }
} }
// data type transform from/to bfloat16
{
paddle::framework::Tensor in;
paddle::framework::Tensor out;
paddle::platform::bfloat16* ptr =
in.mutable_data<paddle::platform::bfloat16>(
paddle::framework::make_ddim({2, 3}), place);
int data_number = 2 * 3;
for (int i = 0; i < data_number; ++i) {
ptr[i] = i;
}
// transform from bfloat16 to other data types
paddle::framework::TransDataType(kernel_bf16, kernel_fp32, in, &out);
float* out_data_float = out.data<float>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_fp64, in, &out);
double* out_data_double = out.data<double>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_int32, in, &out);
int* out_data_int = out.data<int>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_int[i], static_cast<int>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_int64, in, &out);
int64_t* out_data_int64 = out.data<int64_t>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_bool, in, &out);
bool* out_data_bool = out.data<bool>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
}
// transform float to bfloat16
float* in_data_float =
in.mutable_data<float>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_float[i] = i;
}
paddle::framework::TransDataType(kernel_fp32, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_float[i]).x);
}
// transform double to bfloat16
double* in_data_double =
in.mutable_data<double>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_double[i] = i;
}
paddle::framework::TransDataType(kernel_fp64, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_double[i]).x);
}
// transform int to bfloat16
int* in_data_int =
in.mutable_data<int>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_int[i] = i;
}
paddle::framework::TransDataType(kernel_int32, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_int[i]).x);
}
// transform int64 to bfloat16
int64_t* in_data_int64 =
in.mutable_data<int64_t>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_int64[i] = i;
}
paddle::framework::TransDataType(kernel_int64, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_int64[i]).x);
}
// transform bool to bfloat16
bool* in_data_bool =
in.mutable_data<bool>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_bool[i] = i;
}
paddle::framework::TransDataType(kernel_bool, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_bool[i]).x);
}
}
} }
...@@ -74,6 +74,7 @@ set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto ...@@ -74,6 +74,7 @@ set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
eager_deletion_pass eager_deletion_pass
buffer_shared_inplace_op_pass buffer_shared_inplace_op_pass
buffer_shared_cross_op_memory_reuse_pass buffer_shared_cross_op_memory_reuse_pass
inplace_addto_op_pass
set_reader_device_info_utils set_reader_device_info_utils
add_reader_dependency_pass) add_reader_dependency_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/all_reduce_op_handle.h" #include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/details/container_cast.h" #include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h" #include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h" #include "paddle/fluid/framework/details/variable_visitor.h"
...@@ -34,14 +36,24 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, ...@@ -34,14 +36,24 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLCommunicator *ctxs) const platform::NCCLCommunicator *ctxs)
: NCCLOpHandleBase(node, places, ctxs), local_scopes_(local_scopes) { : NCCLOpHandleBase(node, places, ctxs), local_scopes_(local_scopes) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size(),
platform::errors::InvalidArgument(
"The number of places and the number of local scopes "
"should be equal, but got number of places is %d and "
"number of local scopes is %d.",
places_.size(), local_scopes_.size()));
} }
#else #else
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places) const std::vector<platform::Place> &places)
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) { : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size(),
platform::errors::InvalidArgument(
"The number of places and the number of local scopes "
"should be equal, but got number of places is %d and "
"number of local scopes is %d.",
places_.size(), local_scopes_.size()));
} }
#endif #endif
...@@ -60,13 +72,25 @@ void AllReduceOpHandle::AllReduceImpl( ...@@ -60,13 +72,25 @@ void AllReduceOpHandle::AllReduceImpl(
const std::vector<VarHandle *> &in_var_handles, const std::vector<VarHandle *> &in_var_handles,
const std::vector<VarHandle *> &out_var_handles) { const std::vector<VarHandle *> &out_var_handles) {
size_t num_places = places_.size(); size_t num_places = places_.size();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(in_var_handles.size(), num_places,
in_var_handles.size(), num_places, platform::errors::InvalidArgument(
"The NoDummyInputSize should be equal to the number of places."); "The NoDummyInputSize should be equal "
"to the number of places, but got NoDummyInputSize is "
"%d and the number of place is %d.",
in_var_handles.size(), num_places));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(), in_var_handles.size(), out_var_handles.size(),
"The NoDummyInputSize and NoDummyOutputSize should be equal."); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(local_exec_scopes_.size(), num_places); "The NoDummyInputSize and NoDummyOutputSize should be "
"equal, but got NoDummyInputSize is %d and NoDummyOutputSize is %d.",
in_var_handles.size(), out_var_handles.size()));
PADDLE_ENFORCE_EQ(
local_exec_scopes_.size(), num_places,
platform::errors::InvalidArgument(
"The number of local scopes should be equal "
"to the number of places, but got the number of local scopes is "
"%d and the number of place is %d.",
in_var_handles.size(), num_places));
std::vector<const void *> lod_tensor_data; std::vector<const void *> lod_tensor_data;
std::vector<platform::Place> places; std::vector<platform::Place> places;
...@@ -78,23 +102,36 @@ void AllReduceOpHandle::AllReduceImpl( ...@@ -78,23 +102,36 @@ void AllReduceOpHandle::AllReduceImpl(
for (size_t i = 0; i < local_exec_scopes_.size(); ++i) { for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
auto &local_scope = local_exec_scopes_[i]; auto &local_scope = local_exec_scopes_[i];
auto var = local_scope->FindVar(in_var_handles[i]->name()); auto var = local_scope->FindVar(in_var_handles[i]->name());
PADDLE_ENFORCE_NOT_NULL(var, "%s is not found int scope.", PADDLE_ENFORCE_NOT_NULL(var, platform::errors::NotFound(
in_var_handles[i]->name()); "Variable %s is not found in local scope.",
in_var_handles[i]->name()));
auto &lod_tensor = var->Get<LoDTensor>(); auto &lod_tensor = var->Get<LoDTensor>();
if (i == 0) { if (i == 0) {
numel = static_cast<int64_t>(lod_tensor.numel()); numel = static_cast<int64_t>(lod_tensor.numel());
// only enforce place0, we will enforce other palce numel == place0 numel // only enforce place0, we will enforce other palce numel == place0 numel
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
numel, 0, platform::errors::InvalidArgument( numel, 0,
"The numel of tensos=[%s] must > 0. But now numel=[%d]", platform::errors::PreconditionNotMet(
"The numel of tensor %s should be > 0, but got numel is %d.",
in_var_handles[i]->name(), numel)); in_var_handles[i]->name(), numel));
dtype = lod_tensor.type(); dtype = lod_tensor.type();
is_gpu_place = platform::is_gpu_place(lod_tensor.place()); is_gpu_place = platform::is_gpu_place(lod_tensor.place());
} }
PADDLE_ENFORCE_EQ(numel, static_cast<int64_t>(lod_tensor.numel())); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(dtype, lod_tensor.type()); numel, static_cast<int64_t>(lod_tensor.numel()),
PADDLE_ENFORCE_EQ(is_gpu_place, platform::is_gpu_place(lod_tensor.place())); platform::errors::PreconditionNotMet(
"The size of tensors of the same variable in different local "
"scopes should be equal."));
PADDLE_ENFORCE_EQ(
dtype, lod_tensor.type(),
platform::errors::PreconditionNotMet(
"The dtype of tensors of the same variable in different local "
"scopes should be equal."));
PADDLE_ENFORCE_EQ(is_gpu_place, platform::is_gpu_place(lod_tensor.place()),
platform::errors::PreconditionNotMet(
"The place type of tensors of the same variable "
"in different local scopes should be equal."));
lod_tensor_data.emplace_back(lod_tensor.data<void>()); lod_tensor_data.emplace_back(lod_tensor.data<void>());
places.emplace_back(lod_tensor.place()); places.emplace_back(lod_tensor.place());
...@@ -102,8 +139,12 @@ void AllReduceOpHandle::AllReduceImpl( ...@@ -102,8 +139,12 @@ void AllReduceOpHandle::AllReduceImpl(
VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name() VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name()
<< ", out_name:" << out_var_handles[i]->name(); << ", out_name:" << out_var_handles[i]->name();
PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(), PADDLE_ENFORCE_EQ(
"The name of input and output should be equal."); in_var_handles[i]->name(), out_var_handles[i]->name(),
platform::errors::InvalidArgument(
"The name of input and output of all_reduce op should be equal, "
"but got input is %s and output is %s.",
in_var_handles[i]->name(), out_var_handles[i]->name()));
} }
std::vector<std::string> grad_var_names; std::vector<std::string> grad_var_names;
...@@ -122,7 +163,9 @@ void AllReduceOpHandle::AllReduceFunc( ...@@ -122,7 +163,9 @@ void AllReduceOpHandle::AllReduceFunc(
const std::vector<std::string> &out_var_names) { const std::vector<std::string> &out_var_names) {
if (is_gpu_place(places[0])) { if (is_gpu_place(places[0])) {
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
PADDLE_ENFORCE_NOT_NULL(nccl_ctxs_, "nccl_ctxs should not be nullptr."); PADDLE_ENFORCE_NOT_NULL(nccl_ctxs_,
platform::errors::InvalidArgument(
"The nccl context should not be NULL."));
ncclDataType_t nccl_dtype = platform::ToNCCLDataType(dtype); ncclDataType_t nccl_dtype = platform::ToNCCLDataType(dtype);
std::vector<std::function<void()>> all_reduce_calls; std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_exec_scopes_.size(); ++i) { for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
...@@ -134,7 +177,8 @@ void AllReduceOpHandle::AllReduceFunc( ...@@ -134,7 +177,8 @@ void AllReduceOpHandle::AllReduceFunc(
} }
NCCLAllReduceFunc(all_reduce_calls); NCCLAllReduceFunc(all_reduce_calls);
#else #else
PADDLE_THROW("Not compiled with CUDA."); PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with CUDA."));
#endif #endif
} else { // Special handle CPU only Operator's gradient. Like CRF } else { // Special handle CPU only Operator's gradient. Like CRF
auto &trg = *local_exec_scopes_[0] auto &trg = *local_exec_scopes_[0]
......
...@@ -89,8 +89,19 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( ...@@ -89,8 +89,19 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
places_(std::move(places)), places_(std::move(places)),
graphs_(std::move(graphs)) { graphs_(std::move(graphs)) {
VLOG(3) << "build AsyncSSAGraphExecutor"; VLOG(3) << "build AsyncSSAGraphExecutor";
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size(),
PADDLE_ENFORCE_EQ(local_scopes_.size(), local_exec_scopes_.size()); platform::errors::InvalidArgument(
"The number of places and the number of local scopes "
"should be equal, but got number of places is %d and "
"number of local scopes is %d.",
places_.size(), local_scopes_.size()));
PADDLE_ENFORCE_EQ(
local_scopes_.size(), local_exec_scopes_.size(),
platform::errors::InvalidArgument(
"The number of local scopes and the number of local execution scopes "
"should be equal, but got number of local scopes is %d and "
"number of local execution scopes is %d.",
local_scopes_.size(), local_exec_scopes_.size()));
// set the correct size of thread pool to each device. // set the correct size of thread pool to each device.
strategy_.num_threads_ = strategy_.num_threads_ < places_.size() strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "boost/optional.hpp" #include "boost/optional.hpp"
#include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
...@@ -119,6 +120,9 @@ struct BuildStrategy { ...@@ -119,6 +120,9 @@ struct BuildStrategy {
// Turn on inplace by default. // Turn on inplace by default.
bool enable_inplace_{true}; bool enable_inplace_{true};
// Turn off inplace addto by default.
bool enable_addto_{false};
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode, // FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy doesn't tell if // num_trainers is 1, so the current fields of build_strategy doesn't tell if
// it's distributed model. // it's distributed model.
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include <deque> #include <deque>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/fetch_async_op_handle.h" #include "paddle/fluid/framework/details/fetch_async_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
...@@ -48,7 +50,9 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( ...@@ -48,7 +50,9 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
bootstrap_ops_.emplace_back(op); bootstrap_ops_.emplace_back(op);
} }
} }
PADDLE_ENFORCE_GT(op_deps_.size(), 0, "The graph doesn't have operators."); PADDLE_ENFORCE_GT(op_deps_.size(), 0,
platform::errors::PreconditionNotMet(
"The graph doesn't have operators."));
PrepareAtomicOpDeps(); PrepareAtomicOpDeps();
} }
......
...@@ -13,9 +13,11 @@ ...@@ -13,9 +13,11 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/fetch_op_handle.h"
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -138,8 +140,10 @@ void FetchOpHandle::RunImpl() { ...@@ -138,8 +140,10 @@ void FetchOpHandle::RunImpl() {
auto *var_handle = static_cast<VarHandle *>(inputs_[i]); auto *var_handle = static_cast<VarHandle *>(inputs_[i]);
auto &scope = scopes.at(var_handle->scope_idx()); auto &scope = scopes.at(var_handle->scope_idx());
auto *var = scope->FindVar(var_handle->name()); auto *var = scope->FindVar(var_handle->name());
PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope", PADDLE_ENFORCE_NOT_NULL(
var_handle->name()); var,
platform::errors::NotFound(
"Cannot find variable %s in execution scope.", var_handle->name()));
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
auto &t = var->Get<framework::LoDTensor>(); auto &t = var->Get<framework::LoDTensor>();
......
...@@ -167,6 +167,8 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num, ...@@ -167,6 +167,8 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
// more detail see: 180 page of // more detail see: 180 page of
// https://www.openmp.org/wp-content/uploads/OpenMP4.0.0.pdf // https://www.openmp.org/wp-content/uploads/OpenMP4.0.0.pdf
#pragma omp declare reduction(+ : paddle::platform::float16 : omp_out += omp_in) #pragma omp declare reduction(+ : paddle::platform::float16 : omp_out += omp_in)
#pragma omp declare reduction(+ : paddle::platform::bfloat16 : omp_out += \
omp_in)
#endif #endif
template <typename T> template <typename T>
...@@ -205,6 +207,21 @@ void CheckNanInf<paddle::platform::float16>( ...@@ -205,6 +207,21 @@ void CheckNanInf<paddle::platform::float16>(
PrintNanInf(value, numel, print_num, op_type, var_name); PrintNanInf(value, numel, print_num, op_type, var_name);
} }
} }
template <>
void CheckNanInf<paddle::platform::bfloat16>(
const paddle::platform::bfloat16* value, const size_t numel, int print_num,
const std::string& op_type, const std::string& var_name) {
float sum = 0.0f;
#pragma omp parallel for reduction(+ : sum)
for (size_t i = 0; i < numel; ++i) {
sum += static_cast<float>(value[i] - value[i]);
}
if (std::isnan(sum) || std::isinf(sum)) {
PrintNanInf(value, numel, print_num, op_type, var_name);
}
}
#endif #endif
template <> template <>
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include <map> #include <map>
#include <unordered_set> #include <unordered_set>
...@@ -88,6 +89,12 @@ void OpHandleBase::Run(bool use_cuda) { ...@@ -88,6 +89,12 @@ void OpHandleBase::Run(bool use_cuda) {
PADDLE_ENFORCE(!use_cuda); PADDLE_ENFORCE(!use_cuda);
#endif #endif
// skip running current op, used with inplace_addto_op_pass
if (skip_running_) {
VLOG(4) << "skip running: " << Name();
return;
}
RunImpl(); RunImpl();
} }
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
...@@ -52,6 +53,10 @@ class OpHandleBase { ...@@ -52,6 +53,10 @@ class OpHandleBase {
virtual Priority GetPriority() const { return kNormal; } virtual Priority GetPriority() const { return kNormal; }
virtual bool GetSkipRunning() const { return skip_running_; }
virtual void SetSkipRunning(bool skip_runing) { skip_running_ = skip_runing; }
virtual std::string Name() const = 0; virtual std::string Name() const = 0;
void Run(bool use_cuda); void Run(bool use_cuda);
...@@ -131,6 +136,7 @@ class OpHandleBase { ...@@ -131,6 +136,7 @@ class OpHandleBase {
std::map<platform::Place, platform::DeviceContext *> dev_ctxes_; std::map<platform::Place, platform::DeviceContext *> dev_ctxes_;
std::vector<Scope *> local_exec_scopes_; std::vector<Scope *> local_exec_scopes_;
bool skip_running_ = false;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
std::unordered_map<int, cudaEvent_t> events_; std::unordered_map<int, cudaEvent_t> events_;
......
...@@ -13,9 +13,11 @@ ...@@ -13,9 +13,11 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h" #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
...@@ -104,7 +106,12 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( ...@@ -104,7 +106,12 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
places_(places), places_(places),
graphs_(std::move(graphs)), graphs_(std::move(graphs)),
feed_status_(places.size(), FeedStatus::kNone) { feed_status_(places.size(), FeedStatus::kNone) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size(),
platform::errors::InvalidArgument(
"The number of places and the number of local scopes "
"should be equal, but got number of places is %d and "
"number of local scopes is %d.",
places_.size(), local_scopes_.size()));
PADDLE_ENFORCE_EQ(places_.size(), graphs_.size(), PADDLE_ENFORCE_EQ(places_.size(), graphs_.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
......
...@@ -13,10 +13,12 @@ ...@@ -13,10 +13,12 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
...@@ -37,7 +39,13 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor( ...@@ -37,7 +39,13 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
var_infos_(std::move(var_infos)), var_infos_(std::move(var_infos)),
places_(std::move(places)), places_(std::move(places)),
scope_monitor_(places_, local_exec_scopes_) { scope_monitor_(places_, local_exec_scopes_) {
PADDLE_ENFORCE_EQ(local_scopes_.size(), local_exec_scopes_.size()); PADDLE_ENFORCE_EQ(
local_scopes_.size(), local_exec_scopes_.size(),
platform::errors::InvalidArgument(
"The number of local scopes and the number of local execution scopes "
"should be equal, but got number of local scopes is %d and "
"number of local execution scopes is %d.",
local_scopes_.size(), local_exec_scopes_.size()));
PrepareLocalExeScopes(); PrepareLocalExeScopes();
} }
......
...@@ -13,9 +13,11 @@ ...@@ -13,9 +13,11 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/share_tensor_buffer_functor.h" #include "paddle/fluid/framework/details/share_tensor_buffer_functor.h"
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -29,7 +31,8 @@ static inline const Tensor &GetTensorFromVar(const Variable *var) { ...@@ -29,7 +31,8 @@ static inline const Tensor &GetTensorFromVar(const Variable *var) {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>(); return var->Get<LoDTensor>();
} else { } else {
PADDLE_THROW("Variable must be type of LoDTensor"); PADDLE_THROW(platform::errors::InvalidArgument(
"Variable must be type of LoDTensor."));
} }
} }
...@@ -37,20 +40,27 @@ static inline Tensor *GetMutableTensorFromVar(Variable *var) { ...@@ -37,20 +40,27 @@ static inline Tensor *GetMutableTensorFromVar(Variable *var) {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>(); return var->GetMutable<LoDTensor>();
} else { } else {
PADDLE_THROW("Variable must be type of LoDTensor"); PADDLE_THROW(platform::errors::InvalidArgument(
"Variable must be type of LoDTensor."));
} }
} }
ShareTensorBufferFunctor::ShareTensorBufferFunctor( ShareTensorBufferFunctor::ShareTensorBufferFunctor(
Scope *scope, size_t scope_idx, const std::string &op_type, Scope *scope, size_t scope_idx, const std::string &op_type,
const std::vector<const ir::MemOptVarInfo *> &in_var_infos, const std::vector<const ir::MemOptVarInfo *> &in_var_infos,
const std::vector<std::string> &out_var_names) const std::vector<std::string> &out_var_names, bool share_dims)
: scope_(scope), : scope_(scope),
scope_idx_(scope_idx), scope_idx_(scope_idx),
op_type_(op_type), op_type_(op_type),
in_var_infos_(in_var_infos), in_var_infos_(in_var_infos),
out_var_names_(out_var_names) { out_var_names_(out_var_names),
PADDLE_ENFORCE_EQ(in_var_infos_.size(), out_var_names_.size()); share_dims_(share_dims) {
PADDLE_ENFORCE_EQ(in_var_infos_.size(), out_var_names_.size(),
platform::errors::PreconditionNotMet(
"The number of input variables and output variables "
"should be equal, but got number of input variables is "
"%d and number of output variables is %d.",
in_var_infos_.size(), out_var_names_.size()));
for (size_t i = 0; i < in_var_infos_.size(); ++i) { for (size_t i = 0; i < in_var_infos_.size(); ++i) {
AddReuseVarPair(in_var_infos_[i], out_var_names_[i]); AddReuseVarPair(in_var_infos_[i], out_var_names_[i]);
} }
...@@ -67,32 +77,59 @@ ShareTensorBufferFunctor::ReusedVars() const { ...@@ -67,32 +77,59 @@ ShareTensorBufferFunctor::ReusedVars() const {
void ShareTensorBufferFunctor::AddReuseVarPair( void ShareTensorBufferFunctor::AddReuseVarPair(
const ir::MemOptVarInfo *in_var_info, const std::string &out_var_name) { const ir::MemOptVarInfo *in_var_info, const std::string &out_var_name) {
PADDLE_ENFORCE_NOT_NULL(in_var_info, "in_var_info cannot be nullptr"); PADDLE_ENFORCE_NOT_NULL(
in_var_info,
platform::errors::InvalidArgument(
"The input variables to be inplaced should not be NULL."));
PADDLE_ENFORCE_NE(in_var_info->Name(), out_var_name, PADDLE_ENFORCE_NE(in_var_info->Name(), out_var_name,
"in/out cannot have same name: %s", out_var_name); platform::errors::InvalidArgument(
"The input variable and output variable to be inplaced "
"cannot have the same name: %s.",
out_var_name));
in_var_infos_.emplace_back(in_var_info); in_var_infos_.emplace_back(in_var_info);
out_var_names_.emplace_back(out_var_name); out_var_names_.emplace_back(out_var_name);
} }
void ShareTensorBufferFunctor::CallOnce() { void ShareTensorBufferFunctor::CallOnce() {
PADDLE_ENFORCE(in_out_vars_.empty(), "in_out_vars_ must be initialized here"); PADDLE_ENFORCE(in_out_vars_.empty(),
platform::errors::InvalidArgument(
"The input-output variable pairs to be "
"inplaced should be initialized here."));
for (size_t i = 0; i < in_var_infos_.size(); ++i) { for (size_t i = 0; i < in_var_infos_.size(); ++i) {
auto *in_var = exec_scope_->FindVar(in_var_infos_[i]->Name()); auto *in_var = exec_scope_->FindVar(in_var_infos_[i]->Name());
auto *out_var = exec_scope_->FindVar(out_var_names_[i]); auto *out_var = exec_scope_->FindVar(out_var_names_[i]);
PADDLE_ENFORCE_NOT_NULL(in_var); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE_NOT_NULL(out_var); in_var, platform::errors::NotFound(
PADDLE_ENFORCE_NE(in_var, out_var); "The input variable(%s)to be inplaced should not be NULL.",
in_var_infos_[i]->Name()));
PADDLE_ENFORCE_NOT_NULL(
out_var,
platform::errors::NotFound(
"The output variable(%s) to be inplaced should not be NULL.",
out_var_names_[i]));
PADDLE_ENFORCE_NE(
in_var, out_var,
platform::errors::PreconditionNotMet(
"The input variable and output variable to be inplaced "
"cannot be the same variable(%s).",
out_var_names_[i]));
in_out_vars_.emplace_back(in_var, out_var); in_out_vars_.emplace_back(in_var, out_var);
} }
} }
void ShareTensorBufferFunctor::operator()(Scope *exec_scope) { void ShareTensorBufferFunctor::operator()(Scope *exec_scope) {
if (!exec_scope_) { if (!exec_scope_) {
PADDLE_ENFORCE_NOT_NULL(exec_scope); PADDLE_ENFORCE_NOT_NULL(exec_scope,
platform::errors::InvalidArgument(
"The given execution scope should not be NULL "
"if the cached scope is NULL."));
exec_scope_ = exec_scope; exec_scope_ = exec_scope;
CallOnce(); CallOnce();
} else { } else {
PADDLE_ENFORCE(exec_scope_ == exec_scope, "Scope must be the same"); PADDLE_ENFORCE_EQ(exec_scope_, exec_scope,
platform::errors::InvalidArgument(
"The given execution scope and the cached execution "
"scope should be the same."));
} }
for (size_t i = 0; i < in_var_infos_.size(); ++i) { for (size_t i = 0; i < in_var_infos_.size(); ++i) {
...@@ -115,6 +152,13 @@ void ShareTensorBufferFunctor::operator()(Scope *exec_scope) { ...@@ -115,6 +152,13 @@ void ShareTensorBufferFunctor::operator()(Scope *exec_scope) {
} else { } else {
out_tensor->ShareBufferWith(in_tensor); out_tensor->ShareBufferWith(in_tensor);
// NOTE(zhiqiu): In the case of inplace addto, if the operator of
// the in_out_vars is skipped during running, we should set the dims of
// output as the same as input.
if (share_dims_) {
out_tensor->Resize(in_tensor.dims());
}
VLOG(2) << "Share tensor buffer when running " << op_type_ << " : " VLOG(2) << "Share tensor buffer when running " << op_type_ << " : "
<< in_var_info->Name() << " -> " << out_var_names_[i]; << in_var_info->Name() << " -> " << out_var_names_[i];
} }
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -40,11 +41,13 @@ class ShareTensorBufferFunctor { ...@@ -40,11 +41,13 @@ class ShareTensorBufferFunctor {
ShareTensorBufferFunctor( ShareTensorBufferFunctor(
Scope *scope, size_t scope_idx, const std::string &op_type, Scope *scope, size_t scope_idx, const std::string &op_type,
const std::vector<const ir::MemOptVarInfo *> &in_var_infos, const std::vector<const ir::MemOptVarInfo *> &in_var_infos,
const std::vector<std::string> &out_var_names); const std::vector<std::string> &out_var_names, bool share_dims = false);
void AddReuseVarPair(const ir::MemOptVarInfo *in_var_info, void AddReuseVarPair(const ir::MemOptVarInfo *in_var_info,
const std::string &out_var_name); const std::string &out_var_name);
void SetShareDims(bool share_dims) { share_dims_ = share_dims; }
void operator()(Scope *exec_scope); void operator()(Scope *exec_scope);
std::unordered_map<std::string, std::string> ReusedVars() const; std::unordered_map<std::string, std::string> ReusedVars() const;
...@@ -66,6 +69,11 @@ class ShareTensorBufferFunctor { ...@@ -66,6 +69,11 @@ class ShareTensorBufferFunctor {
std::vector<std::string> out_var_names_; std::vector<std::string> out_var_names_;
std::vector<std::pair<const Variable *, Variable *>> in_out_vars_; std::vector<std::pair<const Variable *, Variable *>> in_out_vars_;
// NOTE(zhiqiu): In the case of inplace addto, if the operator of
// the in_out_vars is skipped during running, we should set the dims of output
// as the same as input.
bool share_dims_{false};
}; };
} // namespace details } // namespace details
......
...@@ -13,8 +13,10 @@ ...@@ -13,8 +13,10 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h" #include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -32,26 +34,35 @@ ComputationOpHandle *GetUniquePendingComputationOpHandle( ...@@ -32,26 +34,35 @@ ComputationOpHandle *GetUniquePendingComputationOpHandle(
for (ir::Node *pending_op : out_var->outputs) { for (ir::Node *pending_op : out_var->outputs) {
auto &op = pending_op->Wrapper<OpHandleBase>(); auto &op = pending_op->Wrapper<OpHandleBase>();
auto *compute_op = dynamic_cast<ComputationOpHandle *>(&op); auto *compute_op = dynamic_cast<ComputationOpHandle *>(&op);
PADDLE_ENFORCE_NOT_NULL(compute_op); PADDLE_ENFORCE_NOT_NULL(
compute_op,
platform::errors::PreconditionNotMet(
"The pending OpHandle should be ComputationOpHandle."));
if (result_op == nullptr) { if (result_op == nullptr) {
result_op = compute_op; result_op = compute_op;
} else { } else {
PADDLE_ENFORCE_EQ(result_op, compute_op); PADDLE_ENFORCE_EQ(
result_op, compute_op,
platform::errors::PreconditionNotMet(
"The pending OpHandle should be the unique one."));
} }
} }
} }
PADDLE_ENFORCE_NOT_NULL(result_op); PADDLE_ENFORCE_NOT_NULL(result_op,
platform::errors::PreconditionNotMet(
"The pending OpHandle should not be NULL."));
return result_op; return result_op;
} }
ShareTensorBufferOpHandle::ShareTensorBufferOpHandle( ShareTensorBufferOpHandle::ShareTensorBufferOpHandle(
ir::Node *node, Scope *scope, size_t scope_idx, const std::string &op_type, ir::Node *node, Scope *scope, size_t scope_idx, const std::string &op_type,
const std::vector<const ir::MemOptVarInfo *> &in_var_infos, const std::vector<const ir::MemOptVarInfo *> &in_var_infos,
const std::vector<std::string> &out_var_names) const std::vector<std::string> &out_var_names, bool share_dims)
: OpHandleBase(node), : OpHandleBase(node),
functor_(scope, scope_idx, op_type, in_var_infos, out_var_names) {} functor_(scope, scope_idx, op_type, in_var_infos, out_var_names,
share_dims) {}
std::unordered_map<std::string, std::string> std::unordered_map<std::string, std::string>
ShareTensorBufferOpHandle::ReusedVars() const { ShareTensorBufferOpHandle::ReusedVars() const {
...@@ -63,6 +74,10 @@ void ShareTensorBufferOpHandle::AddReuseVarPair( ...@@ -63,6 +74,10 @@ void ShareTensorBufferOpHandle::AddReuseVarPair(
functor_.AddReuseVarPair(in_var_info, out_var_name); functor_.AddReuseVarPair(in_var_info, out_var_name);
} }
void ShareTensorBufferOpHandle::SetShareDims(bool share_dims) {
functor_.SetShareDims(share_dims);
}
void ShareTensorBufferOpHandle::InitCUDA() { void ShareTensorBufferOpHandle::InitCUDA() {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
int dev_id = int dev_id =
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_functor.h" #include "paddle/fluid/framework/details/share_tensor_buffer_functor.h"
...@@ -31,7 +32,7 @@ class ShareTensorBufferOpHandle : public OpHandleBase { ...@@ -31,7 +32,7 @@ class ShareTensorBufferOpHandle : public OpHandleBase {
ir::Node *node, Scope *scope, size_t scope_idx, ir::Node *node, Scope *scope, size_t scope_idx,
const std::string &op_type, const std::string &op_type,
const std::vector<const ir::MemOptVarInfo *> &in_vars_infos, const std::vector<const ir::MemOptVarInfo *> &in_vars_infos,
const std::vector<std::string> &out_var_names); const std::vector<std::string> &out_var_names, bool share_dims = false);
std::unordered_map<std::string, std::string> ReusedVars() const; std::unordered_map<std::string, std::string> ReusedVars() const;
...@@ -42,6 +43,8 @@ class ShareTensorBufferOpHandle : public OpHandleBase { ...@@ -42,6 +43,8 @@ class ShareTensorBufferOpHandle : public OpHandleBase {
void AddReuseVarPair(const ir::MemOptVarInfo *in_var_info, void AddReuseVarPair(const ir::MemOptVarInfo *in_var_info,
const std::string &out_var_name); const std::string &out_var_name);
void SetShareDims(bool share_dims);
const ShareTensorBufferFunctor &Functor() const { return functor_; } const ShareTensorBufferFunctor &Functor() const { return functor_; }
protected: protected:
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_executor.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fetch_async_op_handle.h" #include "paddle/fluid/framework/details/fetch_async_op_handle.h"
namespace paddle { namespace paddle {
...@@ -27,8 +28,9 @@ void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops) { ...@@ -27,8 +28,9 @@ void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops) {
PADDLE_ENFORCE_EQ(dynamic_cast<FetchOpHandle*>(op) != nullptr || PADDLE_ENFORCE_EQ(dynamic_cast<FetchOpHandle*>(op) != nullptr ||
dynamic_cast<FetchAsyncOpHandle*>(op) != nullptr, dynamic_cast<FetchAsyncOpHandle*>(op) != nullptr,
true, true,
platform::errors::PreconditionNotMet(
"The input ops of ClearFetchOp function should be " "The input ops of ClearFetchOp function should be "
"FetchOpHandle or FetchAsyncOpHandle."); "FetchOpHandle or FetchAsyncOpHandle."));
for (auto& out_var : op->Node()->outputs) { for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var); graph->RemoveNode(out_var);
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -138,7 +139,10 @@ inline FetchResultType ThreadedSSAGraphExecutor::RunImpl( ...@@ -138,7 +139,10 @@ inline FetchResultType ThreadedSSAGraphExecutor::RunImpl(
} }
} }
} }
PADDLE_ENFORCE(ready_ops.empty()); PADDLE_ENFORCE_EQ(
ready_ops.empty(), true,
platform::errors::Fatal("After the execution of computation graph, "
"there are unexecuted operators left."));
} }
// Wait FetchOps. // Wait FetchOps.
...@@ -165,9 +169,8 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -165,9 +169,8 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
FetchResultType *fetch_data, bool return_merged) { FetchResultType *fetch_data, bool return_merged) {
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::unordered_set<VarHandleBase *> local_ready_vars; std::unordered_set<VarHandleBase *> local_ready_vars;
std::unordered_set<std::string> fetch_tensor_set(fetch_tensors.begin(),
fetch_tensors.end()); for (auto &fetch_var_name : fetch_tensors) {
for (auto &fetch_var_name : fetch_tensor_set) {
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
...@@ -231,7 +234,11 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -231,7 +234,11 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
ready_ops->insert(static_cast<OpHandleBase *>(op)); ready_ops->insert(static_cast<OpHandleBase *>(op));
} }
} }
PADDLE_ENFORCE_EQ(local_ready_vars.size(), 0); PADDLE_ENFORCE_EQ(
local_ready_vars.size(), 0,
platform::errors::Fatal(
"The number of ready variables should be 0, but got %d.",
local_ready_vars.size()));
} }
void ThreadedSSAGraphExecutor::InsertPendingOp( void ThreadedSSAGraphExecutor::InsertPendingOp(
...@@ -277,7 +284,9 @@ void ThreadedSSAGraphExecutor::PrepareOpDeps() { ...@@ -277,7 +284,9 @@ void ThreadedSSAGraphExecutor::PrepareOpDeps() {
} }
} }
op_deps_->num_ops_ = ready_ops.size() + pending_ops.size(); op_deps_->num_ops_ = ready_ops.size() + pending_ops.size();
PADDLE_ENFORCE_GT(op_deps_->num_ops_, 0, "The graph doesn't have operators."); PADDLE_ENFORCE_GT(
op_deps_->num_ops_, 0,
platform::errors::InvalidArgument("The graph doesn't have operators."));
for (auto ready_var : ready_vars) { for (auto ready_var : ready_vars) {
pending_vars.erase(ready_var); pending_vars.erase(ready_var);
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include <ThreadPool.h> // ThreadPool in thrird party
#include <deque> #include <deque>
#include <functional> #include <functional>
#include <list> #include <list>
...@@ -24,8 +26,6 @@ ...@@ -24,8 +26,6 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <ThreadPool.h> // ThreadPool in thrird party
#include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h" #include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h"
......
...@@ -54,8 +54,10 @@ struct VarHandleBase { ...@@ -54,8 +54,10 @@ struct VarHandleBase {
void AddOutput(OpHandleBase* out, ir::Node* node) { void AddOutput(OpHandleBase* out, ir::Node* node) {
if (pending_ops_.find(out) == pending_ops_.end()) { if (pending_ops_.find(out) == pending_ops_.end()) {
PADDLE_ENFORCE(out != nullptr, "The output of %s should not be nullptr", PADDLE_ENFORCE_NOT_NULL(out,
this->Node()->Name()); platform::errors::InvalidArgument(
"The output added to VarHandle %s is NULL.",
this->Node()->Name()));
pending_ops_.insert(out); pending_ops_.insert(out);
node_->outputs.push_back(node); node_->outputs.push_back(node);
} }
...@@ -120,7 +122,10 @@ struct VarHandle : public VarHandleBase { ...@@ -120,7 +122,10 @@ struct VarHandle : public VarHandleBase {
bool HasEvent() { return has_event_; } bool HasEvent() { return has_event_; }
const cudaEvent_t& GetEvent() { const cudaEvent_t& GetEvent() {
PADDLE_ENFORCE(HasEvent(), "The event is not set."); PADDLE_ENFORCE_EQ(
HasEvent(), true,
platform::errors::PreconditionNotMet(
"The cuda event is not set, maybe InitCUDA() is not called."));
return event_; return event_;
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/variable_visitor.h" #include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -24,7 +25,9 @@ static void VisitVariable(Variable* var, Func* func) { ...@@ -24,7 +25,9 @@ static void VisitVariable(Variable* var, Func* func) {
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
(*func)(var->GetMutable<SelectedRows>()); (*func)(var->GetMutable<SelectedRows>());
} else { } else {
PADDLE_THROW("Not supported type %s", ToTypeName(var->Type())); PADDLE_THROW(platform::errors::Unimplemented(
"VisitVariable is not supported for type %s.",
ToTypeName(var->Type())));
} }
} }
...@@ -35,7 +38,8 @@ static void VisitVariable(const Variable& var, Func* func) { ...@@ -35,7 +38,8 @@ static void VisitVariable(const Variable& var, Func* func) {
} else if (var.IsType<SelectedRows>()) { } else if (var.IsType<SelectedRows>()) {
(*func)(var.Get<SelectedRows>()); (*func)(var.Get<SelectedRows>());
} else { } else {
PADDLE_THROW("Not supported type %s", ToTypeName(var.Type())); PADDLE_THROW(platform::errors::Unimplemented(
"VisitVariable is not supported for type %s.", ToTypeName(var.Type())));
} }
} }
...@@ -50,7 +54,8 @@ struct TensorVisitor { ...@@ -50,7 +54,8 @@ struct TensorVisitor {
template <typename T> template <typename T>
void operator()() { void operator()() {
PADDLE_THROW("Not Support to get LoDTensor from %s", typeid(T).name()); PADDLE_THROW(platform::errors::Unimplemented(
"Getting tensor from type %s is not supported.", typeid(T).name()));
} }
}; };
...@@ -78,8 +83,8 @@ struct ShareDimsAndLoDVisitor { ...@@ -78,8 +83,8 @@ struct ShareDimsAndLoDVisitor {
template <typename T> template <typename T>
void operator()(const T&) { void operator()(const T&) {
PADDLE_ENFORCE("ShareDimsAndLoD is not supported by type %s", PADDLE_THROW(platform::errors::Unimplemented(
typeid(T).name()); "ShareDimsAndLoD is not supported for type %s.", typeid(T).name()));
} }
}; };
...@@ -89,42 +94,54 @@ void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) { ...@@ -89,42 +94,54 @@ void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
} }
struct EnforceShapeAndDTypeEQVisitor { struct EnforceShapeAndDTypeEQVisitor {
const Variable* trg_; const Variable* dst_;
void operator()(const LoDTensor& src) { void operator()(const LoDTensor& src) {
auto& tensor = trg_->Get<LoDTensor>(); auto& tensor = dst_->Get<LoDTensor>();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(src.place().which(), tensor.place().which(),
src.place().which(), tensor.place().which(), platform::errors::PreconditionNotMet(
"The Places of the two Variable must be all on CPU or all on GPU."); "The place type of the two variables is not equal."));
PADDLE_ENFORCE_EQ(src.type(), tensor.type(), PADDLE_ENFORCE_EQ(src.type(), tensor.type(),
"The dtype of the two Variable is not equal."); platform::errors::PreconditionNotMet(
PADDLE_ENFORCE_EQ(src.dims(), tensor.dims(), "The dtype of the two variables is not equal."));
"The dims of the two Variable is not equal."); PADDLE_ENFORCE_EQ(
src.dims(), tensor.dims(),
platform::errors::PreconditionNotMet(
"The layout of the two variables' tensors is not equal."));
PADDLE_ENFORCE_EQ(src.lod(), tensor.lod(), PADDLE_ENFORCE_EQ(src.lod(), tensor.lod(),
"The lod of the two Variable is not equal."); platform::errors::PreconditionNotMet(
PADDLE_ENFORCE_EQ(src.layout(), tensor.layout(), "The lod of the two variable is not equal."));
"The layout of the two Variable's tensor is not equal."); PADDLE_ENFORCE_EQ(
src.layout(), tensor.layout(),
platform::errors::PreconditionNotMet(
"The layout of the two variables' tensors tensor is not equal."));
} }
void operator()(const SelectedRows& src) { void operator()(const SelectedRows& src) {
auto& selected_rows = trg_->Get<SelectedRows>(); auto& selected_rows = dst_->Get<SelectedRows>();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(src.place().which(), selected_rows.place().which(),
src.place().which(), selected_rows.place().which(), platform::errors::PreconditionNotMet(
"The Places of the two Variable must be all on CPU or all on GPU."); "The place type of the two variables is not equal."));
PADDLE_ENFORCE_EQ(src.value().type(), selected_rows.value().type(), PADDLE_ENFORCE_EQ(src.value().type(), selected_rows.value().type(),
"The dtype of the two Variable is not equal."); platform::errors::PreconditionNotMet(
PADDLE_ENFORCE_EQ(src.value().layout(), selected_rows.value().layout(), "The dtype of the two variables is not equal."));
"The layout of the two Variable's tensor is not equal."); PADDLE_ENFORCE_EQ(
src.value().layout(), selected_rows.value().layout(),
platform::errors::PreconditionNotMet(
"The layout of the two variables' tensors is not equal."));
PADDLE_ENFORCE_EQ(src.height(), selected_rows.height(), PADDLE_ENFORCE_EQ(src.height(), selected_rows.height(),
"The height of the two Variable is not equal."); platform::errors::PreconditionNotMet(
"The height of the two variables is not equal."));
PADDLE_ENFORCE_EQ(src.GetCompleteDims(), selected_rows.GetCompleteDims(), PADDLE_ENFORCE_EQ(src.GetCompleteDims(), selected_rows.GetCompleteDims(),
"The dims of the two Variable is not equal."); platform::errors::PreconditionNotMet(
"The dims of the two variables is not equal."));
} }
template <typename T> template <typename T>
void operator()(const T&) { void operator()(const T&) {
PADDLE_ENFORCE("EnforceShapeAndDTypeEQ is not supported by type %s", PADDLE_THROW(platform::errors::Unimplemented(
typeid(T).name()); "EnforceShapeAndDTypeEQ is not supported for type %s.",
typeid(T).name()));
} }
}; };
......
...@@ -441,6 +441,7 @@ class SectionWorker : public DeviceWorker { ...@@ -441,6 +441,7 @@ class SectionWorker : public DeviceWorker {
skip_vars_ = skip_vars; skip_vars_ = skip_vars;
} }
static void ResetBatchId() { batch_id_ = 0; } static void ResetBatchId() { batch_id_ = 0; }
static void ResetThreadCompletedFlag() { threads_completed = false; }
static std::atomic<int> cpu_id_; static std::atomic<int> cpu_id_;
......
...@@ -36,7 +36,15 @@ message AMPConfig { ...@@ -36,7 +36,15 @@ message AMPConfig {
repeated string custom_black_varnames = 9; repeated string custom_black_varnames = 9;
} }
message LocalSGDConfig { optional int32 k_steps = 1 [ default = 4 ]; } message LocalSGDConfig {
optional int32 k_steps = 1 [ default = 1 ];
optional int32 begin_step = 2 [ default = 1 ];
}
message AdaptiveLocalSGDConfig {
optional int32 init_k_steps = 1 [ default = 1 ];
optional int32 begin_step = 2 [ default = 1 ];
}
message GradientMergeConfig { message GradientMergeConfig {
optional int32 k_steps = 1 [ default = 1 ]; optional int32 k_steps = 1 [ default = 1 ];
...@@ -52,6 +60,8 @@ message DGCConfig { ...@@ -52,6 +60,8 @@ message DGCConfig {
message LarsConfig { message LarsConfig {
optional float lars_coeff = 1 [ default = 0.001 ]; optional float lars_coeff = 1 [ default = 0.001 ];
optional float lars_weight_decay = 2 [ default = 0.0005 ]; optional float lars_weight_decay = 2 [ default = 0.0005 ];
optional float epsilon = 3 [ default = 0.0 ];
repeated string exclude_from_weight_decay = 4;
} }
message LambConfig { message LambConfig {
...@@ -116,6 +126,7 @@ message DistributedStrategy { ...@@ -116,6 +126,7 @@ message DistributedStrategy {
optional bool cudnn_exhaustive_search = 21 [ default = true ]; optional bool cudnn_exhaustive_search = 21 [ default = true ];
optional int32 conv_workspace_size_limit = 22 [ default = 4000 ]; optional int32 conv_workspace_size_limit = 22 [ default = 4000 ];
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ]; optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
optional bool adaptive_localsgd = 24 [ default = false ];
optional RecomputeConfig recompute_configs = 101; optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102; optional AMPConfig amp_configs = 102;
...@@ -126,6 +137,7 @@ message DistributedStrategy { ...@@ -126,6 +137,7 @@ message DistributedStrategy {
optional AsyncConfig a_sync_configs = 107; optional AsyncConfig a_sync_configs = 107;
optional LarsConfig lars_configs = 108; optional LarsConfig lars_configs = 108;
optional LambConfig lamb_configs = 109; optional LambConfig lamb_configs = 109;
optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
optional BuildStrategy build_strategy = 201; optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202; optional ExecutionStrategy execution_strategy = 202;
} }
......
...@@ -23,6 +23,7 @@ template <typename T> ...@@ -23,6 +23,7 @@ template <typename T>
static ::DLDataType GetDLDataTypeCode() { static ::DLDataType GetDLDataTypeCode() {
::DLDataType dtype; ::DLDataType dtype;
if (std::is_same<T, platform::float16>::value || if (std::is_same<T, platform::float16>::value ||
std::is_same<T, platform::bfloat16>::value ||
std::is_floating_point<T>::value) { std::is_floating_point<T>::value) {
dtype.code = kDLFloat; dtype.code = kDLFloat;
} else if (std::is_unsigned<T>::value) { } else if (std::is_unsigned<T>::value) {
......
...@@ -19,6 +19,8 @@ limitations under the License. */ ...@@ -19,6 +19,8 @@ limitations under the License. */
namespace gloo { namespace gloo {
namespace rendezvous { namespace rendezvous {
constexpr int kNodeSize = 136;
HdfsStore::HdfsStore(const std::string& path) { HdfsStore::HdfsStore(const std::string& path) {
path_ = path; path_ = path;
wait_sleep_ms_ = 10000; wait_sleep_ms_ = 10000;
...@@ -213,12 +215,14 @@ void ParallelConnectContext::connectFullMesh( ...@@ -213,12 +215,14 @@ void ParallelConnectContext::connectFullMesh(
storeKey << rank; storeKey << rank;
store.set(storeKey.str(), allBytes); store.set(storeKey.str(), allBytes);
auto total_add_size = kNodeSize * (size - 1);
std::vector<std::shared_ptr<std::thread>> connect_threads(thread_num_); std::vector<std::shared_ptr<std::thread>> connect_threads(thread_num_);
// Connect every pair // Connect every pair
for (uint32_t i = 0; i < connect_threads.size(); ++i) { for (uint32_t i = 0; i < connect_threads.size(); ++i) {
connect_threads[i].reset(new std::thread( connect_threads[i].reset(new std::thread(
[&store, &transportContext, this](size_t thread_idx, [&store, &transportContext, total_add_size, this](
size_t thread_num) -> void { size_t thread_idx, size_t thread_num) -> void {
for (int i = thread_idx; i < size; i += thread_num) { for (int i = thread_idx; i < size; i += thread_num) {
if (i == rank) { if (i == rank) {
continue; continue;
...@@ -226,8 +230,23 @@ void ParallelConnectContext::connectFullMesh( ...@@ -226,8 +230,23 @@ void ParallelConnectContext::connectFullMesh(
// Wait for address of other side of this pair to become available // Wait for address of other side of this pair to become available
std::string key = std::to_string(i); std::string key = std::to_string(i);
store.wait({key}, getTimeout()); store.wait({key}, getTimeout());
std::vector<char> allAddrs;
auto max_retry_times = 5;
// Connect to other side of this pair // Connect to other side of this pair
auto allAddrs = store.get(key);
while (max_retry_times > 0) {
allAddrs = store.get(key);
VLOG(3) << "store get all address size: " << allAddrs.size()
<< " except: " << total_add_size;
if (allAddrs.size() == static_cast<size_t>(total_add_size)) {
break;
}
--max_retry_times;
}
auto addr = extractAddress(allAddrs, i); auto addr = extractAddress(allAddrs, i);
transportContext->getPair(i)->connect(addr); transportContext->getPair(i)->connect(addr);
} }
......
...@@ -25,7 +25,7 @@ bool NCCLWrapper::is_initialized_ = false; ...@@ -25,7 +25,7 @@ bool NCCLWrapper::is_initialized_ = false;
void NCCLWrapper::InitNCCL() { void NCCLWrapper::InitNCCL() {
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
PADDLE_ENFORCE(platform::dynload::ncclCommInitRank( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitRank(
&(nccl_info_.comm_), nccl_info_.global_ranks_, nccl_info_.nccl_id_, &(nccl_info_.comm_), nccl_info_.global_ranks_, nccl_info_.nccl_id_,
nccl_info_.my_global_rank_)); nccl_info_.my_global_rank_));
#endif #endif
...@@ -41,7 +41,8 @@ void NCCLWrapper::SetNCCLId(const NCCLInfo& nccl_info) { ...@@ -41,7 +41,8 @@ void NCCLWrapper::SetNCCLId(const NCCLInfo& nccl_info) {
NCCLInfo NCCLWrapper::GetNCCLId() { NCCLInfo NCCLWrapper::GetNCCLId() {
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(&(nccl_info_.nccl_id_))); PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::ncclGetUniqueId(&(nccl_info_.nccl_id_)));
#endif #endif
return nccl_info_; return nccl_info_;
} }
...@@ -52,8 +53,8 @@ void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank, ...@@ -52,8 +53,8 @@ void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank,
nccl_info_.local_rank_ = local_rank; nccl_info_.local_rank_ = local_rank;
nccl_info_.my_global_rank_ = global_rank; nccl_info_.my_global_rank_ = global_rank;
nccl_info_.global_ranks_ = ranks; nccl_info_.global_ranks_ = ranks;
PADDLE_ENFORCE(cudaSetDevice(local_rank)); PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(local_rank));
PADDLE_ENFORCE(cudaStreamCreate(&(nccl_info_.stream_))); PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&(nccl_info_.stream_)));
#endif #endif
return; return;
} }
...@@ -65,7 +66,7 @@ void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope, ...@@ -65,7 +66,7 @@ void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope,
auto var = scope.FindVar(name); auto var = scope.FindVar(name);
LoDTensor* tensor = var->GetMutable<LoDTensor>(); LoDTensor* tensor = var->GetMutable<LoDTensor>();
int32_t total_size = tensor->numel(); int32_t total_size = tensor->numel();
PADDLE_ENFORCE(platform::dynload::ncclBcast( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
reinterpret_cast<void*>(tensor->data<float>()), total_size, ncclFloat, reinterpret_cast<void*>(tensor->data<float>()), total_size, ncclFloat,
root_rank, nccl_info_.comm_, nccl_info_.stream_)); root_rank, nccl_info_.comm_, nccl_info_.stream_));
cudaStreamSynchronize(nccl_info_.stream_); cudaStreamSynchronize(nccl_info_.stream_);
......
...@@ -21,10 +21,46 @@ limitations under the License. */ ...@@ -21,10 +21,46 @@ limitations under the License. */
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(int64_t device_id) {
#ifdef PADDLE_WITH_CUDA
static int64_t num_cuda_devices = -1;
static std::once_flag num_devices_init_flag;
static std::deque<std::once_flag> cuda_device_flags;
static std::vector<std::shared_ptr<Generator>> default_cuda_generators;
std::call_once(num_devices_init_flag, []() {
num_cuda_devices = paddle::platform::GetCUDADeviceCount();
cuda_device_flags.resize(num_cuda_devices);
default_cuda_generators.resize(num_cuda_devices);
});
if (device_id < 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"cuda device id shoule be greater than 0"));
}
std::call_once(cuda_device_flags[device_id], [device_id]() {
default_cuda_generators[device_id] =
std::make_shared<Generator>(GetRandomSeed(), device_id);
VLOG(4) << "initial seed: "
<< default_cuda_generators[device_id]->GetCurrentSeed();
});
return default_cuda_generators[device_id];
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"getDefaultCUDAGenerator only support in CUDA place"));
#endif
}
const std::shared_ptr<Generator>& DefaultCPUGenerator() { const std::shared_ptr<Generator>& DefaultCPUGenerator() {
static auto default_cpu_generator = static auto default_cpu_generator =
std::make_shared<Generator>(GetRandomSeed()); std::make_shared<Generator>(GetRandomSeed());
...@@ -103,6 +139,7 @@ uint64_t Generator::Seed() { ...@@ -103,6 +139,7 @@ uint64_t Generator::Seed() {
void Generator::SetCurrentSeed(uint64_t seed) { void Generator::SetCurrentSeed(uint64_t seed) {
std::lock_guard<std::mutex> lock(this->mu_); std::lock_guard<std::mutex> lock(this->mu_);
this->state_.current_seed = seed; this->state_.current_seed = seed;
this->state_.thread_offset = 0;
std::seed_seq seq({seed}); std::seed_seq seq({seed});
this->engine_->seed(seq); this->engine_->seed(seq);
} }
...@@ -123,6 +160,22 @@ uint64_t Generator::Random64() { ...@@ -123,6 +160,22 @@ uint64_t Generator::Random64() {
return (*engine)(); return (*engine)();
} }
std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
uint64_t increament_offset) {
uint64_t cur_offset = this->state_.thread_offset;
#ifdef PADDLE_WITH_CUDA
std::lock_guard<std::mutex> lock(this->mu_);
this->state_.thread_offset += increament_offset;
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Increment Offset only support in CUDA place"));
#endif
return std::make_pair(static_cast<int>(this->state_.current_seed),
cur_offset);
}
void Generator::SetIsInitPy(bool is_init_py) { void Generator::SetIsInitPy(bool is_init_py) {
this->is_init_py_ = is_init_py; this->is_init_py_ = is_init_py;
VLOG(4) << "SetIsInitPy:" << this->is_init_py_; VLOG(4) << "SetIsInitPy:" << this->is_init_py_;
......
...@@ -38,6 +38,7 @@ static uint64_t GetRandomSeed() { ...@@ -38,6 +38,7 @@ static uint64_t GetRandomSeed() {
struct GeneratorState { struct GeneratorState {
int64_t device = -1; int64_t device = -1;
uint64_t current_seed = 34342423252; uint64_t current_seed = 34342423252;
uint64_t thread_offset = 0;
std::mt19937_64 cpu_engine; std::mt19937_64 cpu_engine;
}; };
...@@ -49,6 +50,7 @@ struct Generator { ...@@ -49,6 +50,7 @@ struct Generator {
this->state_.cpu_engine = *engine; this->state_.cpu_engine = *engine;
this->state_.device = -1; this->state_.device = -1;
this->state_.current_seed = seed; this->state_.current_seed = seed;
this->state_.thread_offset = 0;
this->engine_ = engine; this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine; << ", cpu engine: " << &this->state_.cpu_engine;
...@@ -59,11 +61,25 @@ struct Generator { ...@@ -59,11 +61,25 @@ struct Generator {
this->state_.cpu_engine = *engine; this->state_.cpu_engine = *engine;
this->state_.device = -1; this->state_.device = -1;
this->state_.current_seed = seed; this->state_.current_seed = seed;
this->state_.thread_offset = 0;
this->engine_ = engine; this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine; << ", cpu engine: " << &this->state_.cpu_engine;
this->is_init_py_ = true; // TODO(zhiqiu): remove it in future this->is_init_py_ = true; // TODO(zhiqiu): remove it in future
} }
Generator(uint64_t seed, uint64_t device_id) {
std::seed_seq seq({seed});
auto engine = std::make_shared<std::mt19937_64>(seq);
this->state_.cpu_engine = *engine;
this->state_.device = device_id;
this->state_.current_seed = seed;
this->state_.thread_offset = 0;
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
this->is_init_py_ = false; // TODO(zhiqiu): remove it in future
}
Generator(const Generator& other) = delete; Generator(const Generator& other) = delete;
// get random state // get random state
...@@ -83,8 +99,11 @@ struct Generator { ...@@ -83,8 +99,11 @@ struct Generator {
uint64_t Random64(); uint64_t Random64();
std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t increament_offset);
void SetIsInitPy(bool); void SetIsInitPy(bool);
bool GetIsInitPy() const; bool GetIsInitPy() const;
uint64_t get_device_id() { return this->state_.device; }
private: private:
GeneratorState state_; GeneratorState state_;
...@@ -105,5 +124,8 @@ std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine(); ...@@ -105,5 +124,8 @@ std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine();
std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t); std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(
int64_t device_id = -1);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -102,6 +102,8 @@ if(WITH_MKLDNN) ...@@ -102,6 +102,8 @@ if(WITH_MKLDNN)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn) pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn) pass_library(conv_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(scale_matmul_fuse_pass inference DIR mkldnn) pass_library(scale_matmul_fuse_pass inference DIR mkldnn)
pass_library(cpu_bfloat16_placement_pass inference DIR mkldnn)
pass_library(cpu_bfloat16_pass inference DIR mkldnn)
pass_library(fc_mkldnn_pass inference DIR mkldnn) pass_library(fc_mkldnn_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn) pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn) pass_library(cpu_quantize_pass inference DIR mkldnn)
...@@ -162,4 +164,6 @@ endif() ...@@ -162,4 +164,6 @@ endif()
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor) cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass) cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass) cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass)
cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass)
cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass)
endif () endif ()
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -225,3 +226,14 @@ REGISTER_PASS(conv_affine_channel_fuse_pass, ...@@ -225,3 +226,14 @@ REGISTER_PASS(conv_affine_channel_fuse_pass,
paddle::framework::ir::ConvAffineChannelFusePass); paddle::framework::ir::ConvAffineChannelFusePass);
REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass, REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass,
paddle::framework::ir::ConvEltwiseAddAffineChannelFusePass); paddle::framework::ir::ConvEltwiseAddAffineChannelFusePass);
REGISTER_PASS_CAPABILITY(conv_affine_channel_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("affine_channel", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_affine_channel_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("elementwise_add", 0)
.EQ("affine_channel", 0));
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -372,3 +373,14 @@ REGISTER_PASS(depthwise_conv_bn_fuse_pass, ...@@ -372,3 +373,14 @@ REGISTER_PASS(depthwise_conv_bn_fuse_pass,
paddle::framework::ir::DepthwiseConvBNFusePass); paddle::framework::ir::DepthwiseConvBNFusePass);
REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass, REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass); paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass);
REGISTER_PASS_CAPABILITY(conv_bn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("batch_norm", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_bn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("elementwise_add", 0)
.EQ("batch_norm", 0));
...@@ -11,9 +11,9 @@ ...@@ -11,9 +11,9 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h" #include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -116,3 +116,10 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -116,3 +116,10 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(conv_elementwise_add2_act_fuse_pass, REGISTER_PASS(conv_elementwise_add2_act_fuse_pass,
paddle::framework::ir::ConvElementwiseAdd2ActFusePass); paddle::framework::ir::ConvElementwiseAdd2ActFusePass);
REGISTER_PASS_CAPABILITY(conv_elementwise_add2_act_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("elementwise_add", 0)
.EQ("relu", 0)
.EQ("identity", 0));
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h" #include "paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -102,3 +103,10 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -102,3 +103,10 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(conv_elementwise_add_act_fuse_pass, REGISTER_PASS(conv_elementwise_add_act_fuse_pass,
paddle::framework::ir::ConvElementwiseAddActFusePass); paddle::framework::ir::ConvElementwiseAddActFusePass);
REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("elementwise_add", 0)
.EQ("relu", 0)
.EQ("identity", 0));
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h" #include "paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -89,3 +89,8 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -89,3 +89,8 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(conv_elementwise_add_fuse_pass, REGISTER_PASS(conv_elementwise_add_fuse_pass,
paddle::framework::ir::ConvElementwiseAddFusePass); paddle::framework::ir::ConvElementwiseAddFusePass);
REGISTER_PASS_CAPABILITY(conv_elementwise_add_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("elementwise_add", 0));
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -334,3 +335,8 @@ void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const { ...@@ -334,3 +335,8 @@ void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
REGISTER_PASS(embedding_eltwise_layernorm_fuse_pass, REGISTER_PASS(embedding_eltwise_layernorm_fuse_pass,
paddle::framework::ir::EmbeddingEltwiseLayerNormFusePass); paddle::framework::ir::EmbeddingEltwiseLayerNormFusePass);
REGISTER_PASS_CAPABILITY(embedding_eltwise_layernorm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("lookup_table", 0)
.EQ("elementweise_add", 0));
...@@ -16,12 +16,13 @@ limitations under the License. */ ...@@ -16,12 +16,13 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h" #include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
TEST(SkipLayerNormFusePass, basic) { TEST(EmbeddingElewiseLayernormFusePass, basic) {
// inputs operator output // inputs operator output
// -------------------------------------------------------------------- // --------------------------------------------------------------------
// (x, y) elementwise_add -> elementwise_out // (x, y) elementwise_add -> elementwise_out
...@@ -91,6 +92,12 @@ TEST(SkipLayerNormFusePass, basic) { ...@@ -91,6 +92,12 @@ TEST(SkipLayerNormFusePass, basic) {
"The number of fusion nodes does not meet expectations after fuse")); "The number of fusion nodes does not meet expectations after fuse"));
} }
TEST(EmbeddingElewiseLayernormFusePass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("embedding_eltwise_layernorm_fuse_pass"));
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
...@@ -34,7 +36,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -34,7 +36,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
// Build pattern // Build pattern
PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x")) PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x"))
->assert_is_op_input("lookup_table") ->assert_is_op_input("lookup_table_v2")
->assert_var_not_persistable(); ->assert_var_not_persistable();
patterns::Embedding embedding_pattern(pattern, name_scope); patterns::Embedding embedding_pattern(pattern, name_scope);
// TODO(jczaja): Intermediate can only be for val that are not used anywhere // TODO(jczaja): Intermediate can only be for val that are not used anywhere
...@@ -256,3 +258,11 @@ void EmbeddingFCLSTMFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -256,3 +258,11 @@ void EmbeddingFCLSTMFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(embedding_fc_lstm_fuse_pass, REGISTER_PASS(embedding_fc_lstm_fuse_pass,
paddle::framework::ir::EmbeddingFCLSTMFusePass); paddle::framework::ir::EmbeddingFCLSTMFusePass);
REGISTER_PASS_CAPABILITY(embedding_fc_lstm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("lookup_table_v2", 0)
.EQ("mul", 0)
.EQ("elementwise_add", 0)
.EQ("lstm", 0)
.EQ("fused_embedding_fc_lstm", 0));
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -182,3 +183,10 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { ...@@ -182,3 +183,10 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
REGISTER_PASS(fc_fuse_pass, paddle::framework::ir::FCFusePass) REGISTER_PASS(fc_fuse_pass, paddle::framework::ir::FCFusePass)
.RequirePassAttr("use_gpu"); .RequirePassAttr("use_gpu");
REGISTER_PASS_CAPABILITY(fc_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("mul", 0)
.EQ("elementwise_add", 0)
.EQ("relu", 0)
.EQ("fc", 0));
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -125,7 +126,6 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -125,7 +126,6 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto* x_n = subgraph.at(x); auto* x_n = subgraph.at(x);
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, gru_pattern); GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, gru_pattern);
GET_IR_NODE_FROM_SUBGRAPH(gru, gru, gru_pattern); GET_IR_NODE_FROM_SUBGRAPH(gru, gru, gru_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, gru_pattern); GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, gru_pattern);
...@@ -136,10 +136,17 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -136,10 +136,17 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
gru_pattern); gru_pattern);
GET_IR_NODE_FROM_SUBGRAPH(BatchHidden, BatchHidden, gru_pattern); GET_IR_NODE_FROM_SUBGRAPH(BatchHidden, BatchHidden, gru_pattern);
// TODO(wilber): Support origin_mode=True.
if (gru->Op()->GetAttrIfExists<bool>("origin_mode") == true) {
LOG(INFO) << "fc_gru_fuse_pass not supported when origin_mode=True.";
return;
}
if (with_fc_bias) { if (with_fc_bias) {
GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
gru_creater(gru, x_n, w, Weight, Bias, Hidden, fc_bias); gru_creater(gru, x_n, w, Weight, Bias, Hidden, fc_bias);
// Remove unneeded nodes. // Remove unneeded nodes.
...@@ -188,3 +195,16 @@ void FCGRUFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -188,3 +195,16 @@ void FCGRUFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(mul_gru_fuse_pass, paddle::framework::ir::MulGRUFusePass); REGISTER_PASS(mul_gru_fuse_pass, paddle::framework::ir::MulGRUFusePass);
REGISTER_PASS(fc_gru_fuse_pass, paddle::framework::ir::FCGRUFusePass); REGISTER_PASS(fc_gru_fuse_pass, paddle::framework::ir::FCGRUFusePass);
REGISTER_PASS_CAPABILITY(mul_gru_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("mul", 0)
.EQ("gru", 0)
.EQ("fusion_gru", 0));
REGISTER_PASS_CAPABILITY(fc_gru_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("mul", 0)
.EQ("elementwise_add", 0)
.EQ("gru", 0)
.EQ("fusion_gru", 0));
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -196,3 +197,17 @@ void FCLstmFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -196,3 +197,17 @@ void FCLstmFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(mul_lstm_fuse_pass, paddle::framework::ir::MulLstmFusePass); REGISTER_PASS(mul_lstm_fuse_pass, paddle::framework::ir::MulLstmFusePass);
REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass); REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass);
REGISTER_PASS_CAPABILITY(fc_lstm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("mul", 0)
.EQ("elementwise_add", 0)
.EQ("lstm", 0)
.EQ("fusion_lstm", 0));
REGISTER_PASS_CAPABILITY(mul_lstm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("mul", 0)
.EQ("lstm", 0)
.EQ("fusion_lstm", 0));
...@@ -1892,6 +1892,82 @@ PDNode *patterns::QuantizePlacement::operator()( ...@@ -1892,6 +1892,82 @@ PDNode *patterns::QuantizePlacement::operator()(
return op; return op;
} }
PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>();
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types);
return op;
}
PDNode *patterns::OrphanedBfloat16::operator()() {
auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
prev_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"float32";
});
auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput();
auto *op = pattern->NewNode(op_repr())->assert_is_op();
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
auto *op_out = pattern->NewNode(op_out_repr())->AsOutput();
auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op();
next_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"float32";
});
prev_op->LinksTo({prev_out});
op->LinksFrom({prev_out}).LinksTo({op_out});
next_op->LinksFrom({op_out});
return next_op;
}
PDNode *patterns::LastBfloat16Ops::operator()() {
auto *op = pattern->NewNode(op_repr())->assert_is_op();
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
auto *op_out = pattern->NewNode(op_out_repr())->AsOutput();
auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op();
next_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") !=
"bfloat16";
});
op->LinksTo({op_out});
next_op->LinksFrom({op_out});
return next_op;
}
PDNode *patterns::FirstBfloat16Ops::operator()() {
auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
prev_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") !=
"bfloat16";
});
auto *op_in = pattern->NewNode(op_in_repr())->AsOutput();
auto *op = pattern->NewNode(op_repr())->assert_is_op();
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
prev_op->LinksTo({op_in});
op->LinksFrom({op_in});
return op;
}
PDNode *patterns::MKLDNNInPlace::operator()() { PDNode *patterns::MKLDNNInPlace::operator()() {
const std::unordered_set<std::string> &supported_op_types = { const std::unordered_set<std::string> &supported_op_types = {
"abs", "abs",
......
...@@ -1129,6 +1129,47 @@ struct QuantizePlacement : public PatternBase { ...@@ -1129,6 +1129,47 @@ struct QuantizePlacement : public PatternBase {
PATTERN_DECL_NODE(op); PATTERN_DECL_NODE(op);
}; };
struct Bfloat16Placement : public PatternBase {
Bfloat16Placement(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "bfloat16_placement") {}
PDNode* operator()(
const std::unordered_set<std::string>& bfloat16_enabled_op_types);
PATTERN_DECL_NODE(op);
};
struct OrphanedBfloat16 : public PatternBase {
OrphanedBfloat16(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "orphaned_bfloat16") {}
PDNode* operator()();
PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(prev_out);
PATTERN_DECL_NODE(op);
PATTERN_DECL_NODE(op_out);
PATTERN_DECL_NODE(next_op);
};
struct LastBfloat16Ops : public PatternBase {
LastBfloat16Ops(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "last_bfloat16_ops") {}
PDNode* operator()();
PATTERN_DECL_NODE(op);
PATTERN_DECL_NODE(op_out);
PATTERN_DECL_NODE(next_op);
};
struct FirstBfloat16Ops : public PatternBase {
FirstBfloat16Ops(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "first_bfloat16_ops") {}
PDNode* operator()();
PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(op_in);
PATTERN_DECL_NODE(op);
};
// Pattern used for enforcing inplace computation for in-place computation // Pattern used for enforcing inplace computation for in-place computation
// supporting DNNL ops. softmax, batch_norm and layer_norm // supporting DNNL ops. softmax, batch_norm and layer_norm
struct MKLDNNInPlace : public PatternBase { struct MKLDNNInPlace : public PatternBase {
......
...@@ -13,4 +13,6 @@ cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handl ...@@ -13,4 +13,6 @@ cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handl
cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass) cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass)
cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass) cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass)
cc_library(inplace_addto_op_pass SRCS inplace_addto_op_pass.cc DEPS memory_reuse_pass)
cc_test(test_reference_count_pass_last_lived_ops SRCS test_reference_count_pass_last_lived_ops.cc DEPS parallel_executor elementwise_mul_op elementwise_add_op scale_op) cc_test(test_reference_count_pass_last_lived_ops SRCS test_reference_count_pass_last_lived_ops.cc DEPS parallel_executor elementwise_mul_op elementwise_add_op scale_op)
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h" #include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
...@@ -141,11 +142,12 @@ void BufferSharedInplaceOpPass::Run(Graph *graph) const { ...@@ -141,11 +142,12 @@ void BufferSharedInplaceOpPass::Run(Graph *graph) const {
VLOG(4) << "Inplace performed in op " << op_type << ": " VLOG(4) << "Inplace performed in op " << op_type << ": "
<< in_var_handle_ptr->Name() << " -> " << in_var_handle_ptr->Name() << " -> "
<< out_var_handle_ptr->Name() << out_var_handle_ptr->Name()
<< ". Debug String is: " << op->GetOp()->DebugString(); << ". Debug String is: " << op->GetOp()->DebugString()
<< ". ReuseType: " << ReuseType();
} else { } else {
VLOG(3) << "Inplace failed in op " << op_type << ": " VLOG(3) << "Inplace failed in op " << op_type << ": "
<< in_var_handle_ptr->Name() << " -> " << in_var_handle_ptr->Name() << " -> "
<< out_var_handle_ptr->Name(); << out_var_handle_ptr->Name() << ". ReuseType: " << ReuseType();
} }
} }
} }
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class InplaceAddToOpPass : public MemoryReusePass {
protected:
std::string ReuseType() const override { return "inplace_addto"; }
void Run(Graph *graph) const override;
private:
// 1. Add last living op of in_var, add any last living op of out_var
// 2. Set reference count of in_var to be 2
void UpdateLastLiveOpOfVar(details::ComputationOpHandle *op,
details::VarHandle *in_var,
details::VarHandle *out_var) const override {
size_t scope_idx = op->GetScopeIdx();
auto *last_live_ops_of_vars_ =
&Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
auto *var_infos_ = &(Get<MemOptVarInfoMapList>(kMemOptVarInfoMapList));
auto out_var_op_iter =
(*last_live_ops_of_vars_)[scope_idx].find(out_var->Name());
// In Reduce mode, some output variable(gradient of parameter) does not have
// last live ops
details::ComputationOpHandle *last_live_op_of_in_var = nullptr;
if (out_var_op_iter == (*last_live_ops_of_vars_)[scope_idx].end()) {
last_live_op_of_in_var = op;
} else {
PADDLE_ENFORCE_EQ(
out_var_op_iter->second.ops().empty(), false,
platform::errors::InvalidArgument(
"Var(%s)'s last live op should not empty.", out_var->Name()));
last_live_op_of_in_var = *(out_var_op_iter->second.ops().begin());
}
auto *last_live_ops_of_in_var =
(*last_live_ops_of_vars_)[scope_idx][in_var->Name()].mutable_ops();
// last_live_ops_of_in_var->clear();
last_live_ops_of_in_var->insert(last_live_op_of_in_var);
auto in_var_info_iter = (*var_infos_)[scope_idx].find(in_var->Name());
PADDLE_ENFORCE_NE(
in_var_info_iter, (*var_infos_)[scope_idx].end(),
platform::errors::NotFound("Cannot find variable %s.", in_var->Name()));
in_var_info_iter->second->SetRefCnt(2); // before inplace, it is 1
}
};
void InplaceAddToOpPass::Run(Graph *graph) const {
const auto &last_live_ops =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
bool use_cuda = Get<bool>(kUseCuda);
// Currently, only perform InplaceAddToOpPass on cuda place
if (!use_cuda) {
return;
}
// Step 1: Build a reverse map of last_live_ops
// i.e.: op -> vars
std::unordered_map<details::ComputationOpHandle *,
std::unordered_map<std::string, ir::Node *>>
candidate_ops;
for (auto &each_scope_ops : last_live_ops) {
for (auto &pair : each_scope_ops) {
// If variable has more than 1 last lived ops, this variable cannot
// be inplaced.
if (pair.second.ops().size() != 1) {
continue;
}
auto *op = *(pair.second.ops().begin());
const std::string &op_type = op->GetOp()->Type();
const framework::OpDesc *op_desc = op->Node()->Op();
PADDLE_ENFORCE_NOT_NULL(
op_desc, platform::errors::NotFound("Op(%s) can not find opdesc.",
op->Name()));
// only grad op should be processed.
if (op_type != "grad_add") {
continue;
}
const std::string &var_name = pair.first;
auto in_nodes = this->FindNodesByName(var_name, op->Node()->inputs);
if (in_nodes.size() == 1) {
candidate_ops[op][var_name] = *in_nodes.begin();
}
VLOG(4) << "Find op " << op_type << " with input(" << var_name
<< ") that can do inplace add to";
}
}
// Step 2: Check which vars can be inplaced indeed
for (auto &op_vars_pair : candidate_ops) {
auto *op = op_vars_pair.first;
// The original gradient accumulation is g = sum(g_0, g_1,..., g_n), and it
// could be changed as follws if inplace addto is enabled:
// g_sum_0 = g_0
// g_sum_1 = grad_add(g_sum_0, g_1)
// g_sum_2 = grad_add(g_sum_1, g_2)
// ...
// g_sum_n = grad_add(g_sum_n-1, g_n)
// here we will add inplace for each grad_add, for example, for the first
// grad_add, g_sum_0 -> g1, g_sum_1 -> g1, and set grad_add as skipped.
const std::string &op_type = op->GetOp()->Type();
PADDLE_ENFORCE_EQ(op->Node()->inputs.size(), 2,
platform::errors::InvalidArgument(
"The size of inputs of %s should be 2, but got %d",
op_type, op->Node()->inputs.size()));
PADDLE_ENFORCE_EQ(op->Node()->outputs.size(), 1,
platform::errors::InvalidArgument(
"The size of outputs of %s should be 1, but got %d",
op_type, op->Node()->outputs.size()));
auto *left_var_ptr = dynamic_cast<details::VarHandle *>(
&(op->Node()->inputs[0]->Wrapper<details::VarHandleBase>()));
auto *right_var_ptr = dynamic_cast<details::VarHandle *>(
&(op->Node()->inputs[1]->Wrapper<details::VarHandleBase>()));
auto *out_var_ptr = dynamic_cast<details::VarHandle *>(
&(op->Node()->outputs[0]->Wrapper<details::VarHandleBase>()));
if (left_var_ptr == nullptr || right_var_ptr == nullptr ||
out_var_ptr == nullptr) {
continue;
}
// auto *left_generated_op = dynamic_cast<details::ComputationOpHandle *>(
// left_var_ptr->GeneratedOp());
auto *right_generated_op = dynamic_cast<details::ComputationOpHandle *>(
right_var_ptr->GeneratedOp());
auto *out_generated_op = dynamic_cast<details::ComputationOpHandle *>(
out_var_ptr->GeneratedOp());
// NOTE(zhiqiu): currently, only conv2d_grad supports addto strategy
if (right_generated_op->Name() != "conv2d_grad") {
continue;
}
// NOTE(zhiqiu): Normally, if we inplace a->b, we should let a generated
// before b. However, in the situation of inplace addto, we do not care
// the order, since a+b is equal to b+a. Is there any exception for that?
// AddDependencyVar(right_generated_op, left_generated_op);
// no need, as discussed above.
// step (a): inplace right_var->left_var of grad_add
this->AddReuseVar(right_generated_op, left_var_ptr, right_var_ptr);
UpdateLastLiveOpOfVar(right_generated_op, left_var_ptr, right_var_ptr);
VLOG(4) << "Inplace performed in op " << right_generated_op->GetOp()->Type()
<< ": " << left_var_ptr->Name() << " -> " << right_var_ptr->Name()
<< ". Debug String is: "
<< right_generated_op->GetOp()->DebugString()
<< ". ReuseType: " << ReuseType();
// step (b): inplace out -> right_var of grad_add
this->AddReuseVar(out_generated_op, right_var_ptr, out_var_ptr, true);
VLOG(4) << "Inplace performed in op " << op_type << ": "
<< left_var_ptr->Name() << " -> " << out_var_ptr->Name()
<< ". Debug String is: " << op->GetOp()->DebugString()
<< ". ReuseType: " << ReuseType();
// step (c): make right_var cannot inplace afterwards. canbe done
// aotomatically since CollectReusedVars is called before any reuse.
// step (d): make right_var's generated op use addto
right_generated_op->GetOp()->SetAttr("use_addto", true);
// step (e): make grad_add skip running
op->SetSkipRunning(true);
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(inplace_addto_op_pass, paddle::framework::ir::InplaceAddToOpPass)
.RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList)
.RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::ir::kUseCuda);
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h"
#include <functional> #include <functional>
#include <map> #include <map>
#include <string> #include <string>
...@@ -73,6 +74,7 @@ bool MemoryReusePass::TryReuseVar(details::VarHandle *in_var, ...@@ -73,6 +74,7 @@ bool MemoryReusePass::TryReuseVar(details::VarHandle *in_var,
out_var->Name())); out_var->Name()));
if (IsVarPairReusable(*in_var, *out_var)) { if (IsVarPairReusable(*in_var, *out_var)) {
AddReuseVar(op, in_var, out_var); AddReuseVar(op, in_var, out_var);
UpdateLastLiveOpOfVar(op, in_var, out_var);
return true; return true;
} else { } else {
return false; return false;
...@@ -324,7 +326,8 @@ bool MemoryReusePass::IsVarPairReusable( ...@@ -324,7 +326,8 @@ bool MemoryReusePass::IsVarPairReusable(
void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op, void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op,
details::VarHandle *in_var, details::VarHandle *in_var,
details::VarHandle *out_var) const { details::VarHandle *out_var,
bool share_dims) const {
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
(*var_infos_)[op->GetScopeIdx()].count(in_var->Name()), 0, (*var_infos_)[op->GetScopeIdx()].count(in_var->Name()), 0,
platform::errors::NotFound("Var(%s) does not in mem opt var infos.", platform::errors::NotFound("Var(%s) does not in mem opt var infos.",
...@@ -344,13 +347,15 @@ void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op, ...@@ -344,13 +347,15 @@ void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op,
share_buffer_op->AddInput(in_var); share_buffer_op->AddInput(in_var);
} }
if (share_dims) {
share_buffer_op->SetShareDims(true);
}
share_buffer_op->AddReuseVarPair( share_buffer_op->AddReuseVarPair(
(*var_infos_)[op->GetScopeIdx()].at(in_var->Name()).get(), (*var_infos_)[op->GetScopeIdx()].at(in_var->Name()).get(),
out_var->Name()); out_var->Name());
reused_in_var_names_[op->GetScopeIdx()].insert(in_var->Name()); reused_in_var_names_[op->GetScopeIdx()].insert(in_var->Name());
reused_out_var_names_[op->GetScopeIdx()].insert(out_var->Name()); reused_out_var_names_[op->GetScopeIdx()].insert(out_var->Name());
UpdateLastLiveOpOfVar(op, in_var, out_var);
} }
// 1. Set last living op of in_var to be any last living op of out_var // 1. Set last living op of in_var to be any last living op of out_var
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h" #include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
...@@ -92,6 +93,12 @@ class MemoryReusePass : public Pass { ...@@ -92,6 +93,12 @@ class MemoryReusePass : public Pass {
int64_t GetMemorySize(const details::VarHandle &var) const; int64_t GetMemorySize(const details::VarHandle &var) const;
void AddReuseVar(details::ComputationOpHandle *op, details::VarHandle *in_var,
details::VarHandle *out_var, bool share_dims = false) const;
virtual void UpdateLastLiveOpOfVar(details::ComputationOpHandle *op,
details::VarHandle *in_var,
details::VarHandle *out_var) const;
private: private:
VarDesc *GetVarDesc(const details::VarHandle &var) const; VarDesc *GetVarDesc(const details::VarHandle &var) const;
...@@ -109,13 +116,6 @@ class MemoryReusePass : public Pass { ...@@ -109,13 +116,6 @@ class MemoryReusePass : public Pass {
void CollectReusedVars() const; void CollectReusedVars() const;
void AddReuseVar(details::ComputationOpHandle *op, details::VarHandle *in_var,
details::VarHandle *out_var) const;
void UpdateLastLiveOpOfVar(details::ComputationOpHandle *op,
details::VarHandle *in_var,
details::VarHandle *out_var) const;
private: private:
mutable Graph *graph_; mutable Graph *graph_;
mutable bool use_cuda_; mutable bool use_cuda_;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -84,6 +85,19 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -84,6 +85,19 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "do not perform " + type() + "+bias fuse"; VLOG(3) << "do not perform " + type() + "+bias fuse";
return; return;
} }
if (conv->Op()->HasAttr("dilations")) {
auto dilations =
BOOST_GET_CONST(std::vector<int>, conv->Op()->GetAttr("dilations"));
for (const auto& d : dilations) {
if (d != 1) {
LOG(WARNING)
<< "dilation conv not supported in MKLDNN, fuse not apply "
<< "and set conv attribute use_mkldnn = false";
conv->Op()->SetAttr("use_mkldnn", false);
return;
}
}
}
auto* eltwise_bias_tensor = auto* eltwise_bias_tensor =
scope->FindVar(eltwise_bias->Name())->GetMutable<LoDTensor>(); scope->FindVar(eltwise_bias->Name())->GetMutable<LoDTensor>();
...@@ -151,3 +165,8 @@ REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass, ...@@ -151,3 +165,8 @@ REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DTransposeBiasFusePass); paddle::framework::ir::Conv2DTransposeBiasFusePass);
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass, REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
paddle::framework::ir::Conv3DBiasFusePass); paddle::framework::ir::Conv3DBiasFusePass);
REGISTER_PASS_CAPABILITY(conv_bias_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("elementwise_add", 0));
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
namespace paddle { namespace paddle {
...@@ -149,6 +150,12 @@ TEST(ConvBiasFusePass, conv2d_transpose) { ...@@ -149,6 +150,12 @@ TEST(ConvBiasFusePass, conv2d_transpose) {
ASSERT_EQ(pass.type(), std::string("conv2d_transpose")); ASSERT_EQ(pass.type(), std::string("conv2d_transpose"));
} }
TEST(ConvBiasFusePass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("conv_bias_mkldnn_fuse_pass"));
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <memory> #include <memory>
#include <tuple> #include <tuple>
#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -341,3 +342,8 @@ void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { ...@@ -341,3 +342,8 @@ void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass, REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
paddle::framework::ir::ResidualConnectionMKLDNNFusePass); paddle::framework::ir::ResidualConnectionMKLDNNFusePass);
REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("elementwise_add", 0));
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -267,6 +268,12 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) { ...@@ -267,6 +268,12 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
AssertOpsCount(graph, 2, 1); AssertOpsCount(graph, 2, 1);
} }
TEST(ConvElementwiseAddMKLDNNFusePass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("conv_elementwise_add_mkldnn_fuse_pass"));
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void UnlinkNodes(ir::Node* a, ir::Node* b) {
a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
a->outputs.end());
b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
b->inputs.end());
}
void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"first_bfloat16_ops"};
bfloat16_ops();
int quantize_counter = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
if (op->Op()->Type() != "conv2d" && prev_op->Op()->Type() != "quantize") {
VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);
// create a quantize op node
OpDesc q_desc;
q_desc.SetType("quantize");
q_desc.SetInput("Input", std::vector<std::string>({op_in->Name()}));
q_desc.SetOutput("Output",
std::vector<std::string>({quantize_out_node->Name()}));
q_desc.SetAttr("Scale", 1.f);
q_desc.SetAttr("bfloat16", true);
q_desc.SetAttr("output_format", Has("data_layout")
? Get<std::string>("data_layout")
: "NCHW");
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
std::string op_input_name;
for (auto name : op->Op()->InputNames()) {
for (auto input_name : op->Op()->Input(name)) {
if (input_name == op_in->Name()) op_input_name = name;
}
}
PADDLE_ENFORCE_NE(
op_input_name.empty(), true,
platform::errors::NotFound(
"Operator before operator should have input as op output"));
op->Op()->SetInput(op_input_name,
std::vector<std::string>({quantize_out_node->Name()}));
UnlinkNodes(op_in, op);
IR_NODE_LINK_TO(op_in, quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_node);
IR_NODE_LINK_TO(quantize_out_node, op);
quantize_counter++;
}
};
gpd(graph, handler);
PrettyLogDetail("--- added %d quantize op before bfloat16 op",
quantize_counter);
}
void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"last_bfloat16_ops"};
bfloat16_ops();
int force_fp32_counter = 0, dequantize_counter = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, bfloat16_ops);
if ((op->Op()->HasAttr("force_fp32_output") ||
op->Op()->HasProtoAttr("force_fp32_output")) &&
!op->Op()->GetAttrIfExists<bool>("fuse_residual_connection")) {
op->Op()->SetAttr("force_fp32_output", true);
force_fp32_counter++;
} else if (op->Op()->Type() != "prior_box") {
// Create dequantize input variable
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
// create a dequantize op node for output.
OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
auto dequantize_op = g->CreateOpNode(&deq_desc);
std::string op_output_name;
for (auto name : op->Op()->OutputNames()) {
for (auto output_name : op->Op()->Output(name)) {
if (output_name == op_out->Name()) op_output_name = name;
}
}
PADDLE_ENFORCE_NE(
op_output_name.empty(), true,
platform::errors::NotFound(
"Operator after operator should have input as op output"));
op->Op()->SetOutput(op_output_name, std::vector<std::string>(
{dequantize_in_node->Name()}));
UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);
dequantize_counter++;
}
};
gpd(graph, handler);
PrettyLogDetail("--- added %d dequantize op and used %d force_fp32_output",
dequantize_counter, force_fp32_counter);
}
void CPUBFloat16Pass::ApplyImpl(ir::Graph* graph) const {
SetInputDataType(graph);
SetOutputDataType(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(cpu_bfloat16_pass, paddle::framework::ir::CPUBFloat16Pass);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class CPUBFloat16Pass : public Pass {
protected:
void SetInputDataType(ir::Graph* graph) const;
void SetOutputDataType(ir::Graph* graph) const;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn,
const std::string& mkldnn_data_type = "float32",
const bool force_fp32_output = false) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
if (type == "conv2d") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
op->SetAttr("force_fp32_output", force_fp32_output);
} else if (type == "pool2d" || type == "transpose2" || type == "reshape2" ||
type == "dropout") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
} else if (type == "fc") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
} else if (type == "concat") {
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
} else if (type == "matmul" || type == "elementwise_add") {
op->SetInput("X", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
}
}
void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog,
const std::initializer_list<std::string> variable_names,
int* original_nodes_num, int* current_nodes_num) {
auto pass = PassRegistry::Instance().Get("cpu_bfloat16_pass");
graph->reset(pass->Apply(graph->release()));
*original_nodes_num = (*graph)->Nodes().size();
(*graph).reset(pass->Apply((*graph).release()));
*current_nodes_num = (*graph)->Nodes().size();
}
static const std::initializer_list<std::string> variable_names{
"z", "a", "b", "c", "d", "e", "f", "g", "h", "i"};
ProgramDesc BuildProgramDesc(bool use_mkldnn) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "dropout", "Dropout1", {"z"}, {"a"}, use_mkldnn, "float32");
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, "bfloat16");
SetOp(&prog, "pool2d", "Pool1", {"b"}, {"c"}, use_mkldnn, "bfloat16");
SetOp(&prog, "conv2d", "Conv1", {"c"}, {"d"}, use_mkldnn, "bfloat16");
SetOp(&prog, "dropout", "Dropout2", {"d"}, {"e"}, use_mkldnn, "float32");
SetOp(&prog, "transpose2", "Transpose1", {"e"}, {"f"}, use_mkldnn,
"bfloat16");
SetOp(&prog, "reshape2", "Reshape1", {"f"}, {"g"}, use_mkldnn, "bfloat16");
SetOp(&prog, "concat", "Concat1", {"g"}, {"h"}, use_mkldnn, "bfloat16");
SetOp(&prog, "dropout", "Dropout3", {"h"}, {"i"}, use_mkldnn, "float32");
return prog;
}
void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
int transpose_count, int quant_count, int dequant_count,
int added_nodes_count) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num, current_nodes_num;
PreparePass(&graph, prog, variable_names, &original_nodes_num,
&current_nodes_num);
int quantize_nodes_count = 0;
int dequantize_nodes_count = 0;
int conv2d_nodes_count = 0;
int pool2d_nodes_count = 0;
int transpose2_nodes_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == "conv2d") {
conv2d_nodes_count++;
} else if (op->Type() == "pool2d") {
pool2d_nodes_count++;
} else if (op->Type() == "transpose2") {
transpose2_nodes_count++;
} else if (op->Type() == "quantize") {
quantize_nodes_count++;
} else if (op->Type() == "dequantize") {
dequantize_nodes_count++;
}
}
}
EXPECT_EQ(conv2d_nodes_count, conv_count);
EXPECT_EQ(pool2d_nodes_count, pool_count);
EXPECT_EQ(transpose2_nodes_count, transpose_count);
EXPECT_EQ(quantize_nodes_count, quant_count);
EXPECT_EQ(dequantize_nodes_count, dequant_count);
EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
}
TEST(CpuQuantizePass, quantize) {
bool use_mkldnn = true;
// 1 quantize + 1 dequantize
int added_nodes = 2;
MainTest(BuildProgramDesc(use_mkldnn), 2, 1, 1, 1, 2, added_nodes);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(cpu_bfloat16_pass);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void CPUBfloat16PlacementPass::SetMkldnnDataType(
ir::Graph* graph, int* bfloat16_operators) const {
const auto& op_types_list =
Get<std::unordered_set<std::string>>("bfloat16_enabled_op_types");
// set mkldnn_data_type to bfloat16 to all operators that are in
// bfloat16_enabled_op_types vector or they are included to Bfloat16Placement
// pattern
GraphPatternDetector gpd;
patterns::Bfloat16Placement bfloat16_placement_pattern{gpd.mutable_pattern(),
"bfloat16_placement"};
bfloat16_placement_pattern(op_types_list);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_placement_pattern);
if ((op->Op()->HasAttr("mkldnn_data_type") ||
op->Op()->HasProtoAttr("mkldnn_data_type")) &&
!platform::HasOpINT8DataType(op->Op())) {
op->Op()->SetAttr("mkldnn_data_type", std::string("bfloat16"));
(*bfloat16_operators)++;
}
};
gpd(graph, handler);
}
void CPUBfloat16PlacementPass::RemoveOrhanedOperators(
ir::Graph* graph, int* bfloat16_operators) const {
// find orphaned bfloat16 operator that is between two float32 operators
// revert mkldnn_data_type attr to float32
GraphPatternDetector gpd;
patterns::OrphanedBfloat16 orphaned_bfloat16_pattern{gpd.mutable_pattern(),
"orphaned_bfloat16"};
orphaned_bfloat16_pattern();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, orphaned_bfloat16_pattern);
op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
bfloat16_operators--;
};
gpd(graph, handler);
}
void CPUBfloat16PlacementPass::ApplyImpl(ir::Graph* graph) const {
int bfloat16_operators = 0;
SetMkldnnDataType(graph, &bfloat16_operators);
RemoveOrhanedOperators(graph, &bfloat16_operators);
PrettyLogDetail("--- marked %d operators to bfloat16 ",
bfloat16_operators);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(cpu_bfloat16_placement_pass,
paddle::framework::ir::CPUBfloat16PlacementPass)
// a vector of operator type names with bfloat16 support ("conv2d" etc.)
// the second param is the default value for this vector
.DefaultPassAttr("bfloat16_enabled_op_types",
new std::unordered_set<std::string>());
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Specifies which operators should be run on bfloat16.
*/
class CPUBfloat16PlacementPass : public Pass {
protected:
void SetMkldnnDataType(ir::Graph* graph, int* bfloat16_operators) const;
void RemoveOrhanedOperators(ir::Graph* graph, int* bfloat16_operators) const;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::string& mkldnn_data_type = "float32") {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
if (type == "conv2d") {
op->SetAttr("name", name);
op->SetInput("Input", {inputs[0]});
} else if (type == "relu") {
op->SetInput("X", inputs);
} else if (type == "concat") {
op->SetAttr("axis", 1);
op->SetInput("X", {inputs[0], inputs[1]});
} else if (type == "pool2d") {
op->SetInput("X", {inputs[0]});
} else {
FAIL() << "Unexpected operator type.";
}
op->SetOutput("Out", {outputs[0]});
}
// operator mkldnn_data_type
// ---------------------------------------
// (a,b)->concat->c float32
// c->conv->f float32
// f->relu->g float32
// g->pool->h float32
// h->conv->k float32
// k->pool->l float32
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "c", "f", "g", "h", "k", "l"})) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "concat", "concat1", {"a", "b"}, {"c"});
SetOp(&prog, "conv2d", "conv1", {"c"}, {"f"});
SetOp(&prog, "relu", "relu1", {"f"}, {"g"});
SetOp(&prog, "pool2d", "pool1", {"g"}, {"h"});
SetOp(&prog, "conv2d", "conv2", {"h"}, {"k"});
SetOp(&prog, "pool2d", "pool2", {"k"}, {"l"});
return prog;
}
void MainTest(std::initializer_list<std::string> bfloat16_enabled_op_types,
unsigned expected_bfloat16_data_type_count) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
pass->Set("bfloat16_enabled_op_types",
new std::unordered_set<std::string>(bfloat16_enabled_op_types));
graph.reset(pass->Apply(graph.release()));
unsigned bfloat16_data_type_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
if (platform::HasOpBFLOAT16DataType(node->Op())) {
++bfloat16_data_type_count;
}
}
}
EXPECT_EQ(bfloat16_data_type_count, expected_bfloat16_data_type_count);
}
void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
graph.reset(pass->Apply(graph.release()));
unsigned bfloat16_data_type_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
if (platform::HasOpBFLOAT16DataType(node->Op())) {
++bfloat16_data_type_count;
}
}
}
EXPECT_EQ(bfloat16_data_type_count, expected_bfloat16_data_type_count);
}
TEST(Bfloat16PlacementPass, enable_all) {
MainTest({"conv2d", "pool2d", "relu", "concat"}, 6);
}
TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
// 2 conv2d + 2 pool2 - 1 orphaned conv2d
MainTest({"conv2d", "pool2d"}, 3);
}
TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(0); }
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(cpu_bfloat16_placement_pass);
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h" #include "paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -57,3 +58,7 @@ void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const { ...@@ -57,3 +58,7 @@ void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(depthwise_conv_mkldnn_pass, REGISTER_PASS(depthwise_conv_mkldnn_pass,
paddle::framework::ir::DepthwiseConvMKLDNNPass); paddle::framework::ir::DepthwiseConvMKLDNNPass);
REGISTER_PASS_CAPABILITY(depthwise_conv_mkldnn_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"depthwise_conv2d", 0));
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
...@@ -70,6 +72,12 @@ ProgramDesc BuildProgramDesc() { ...@@ -70,6 +72,12 @@ ProgramDesc BuildProgramDesc() {
return prog; return prog;
} }
TEST(DepthwiseConvMKLDNNPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("depthwise_conv_mkldnn_pass"));
}
TEST(DepthwiseConvMKLDNNPass, basic) { TEST(DepthwiseConvMKLDNNPass, basic) {
auto prog = BuildProgramDesc(); auto prog = BuildProgramDesc();
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
namespace paddle { namespace paddle {
...@@ -615,6 +616,16 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -615,6 +616,16 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern); multihead_pattern);
// If weights or biases in qkv's fc are shared by multiple multihead_matmul
// patterns, we do not support this kind of fusion, this pass will not take
// effect.
bool is_fc_params_shared =
mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 ||
mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 ||
eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1;
if (is_fc_params_shared) {
return;
}
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w, fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
reshape2_0, reshape2_qkv_out, scale, scale_out); reshape2_0, reshape2_qkv_out, scale, scale_out);
...@@ -697,3 +708,13 @@ REGISTER_PASS(multihead_matmul_fuse_pass, ...@@ -697,3 +708,13 @@ REGISTER_PASS(multihead_matmul_fuse_pass,
REGISTER_PASS(multihead_matmul_fuse_pass_v2, REGISTER_PASS(multihead_matmul_fuse_pass_v2,
paddle::framework::ir::MultiHeadMatmulV2FusePass); paddle::framework::ir::MultiHeadMatmulV2FusePass);
REGISTER_PASS_CAPABILITY(multihead_matmul_fuse_pass_v2)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("mul", 0)
.EQ("elementwise_add", 0)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.EQ("matmul", 0)
.EQ("softmax", 0));
...@@ -12,6 +12,7 @@ limitations under the License. */ ...@@ -12,6 +12,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h" // NOLINT #include "paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h" // NOLINT
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h" #include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -133,6 +134,12 @@ TEST(MultiHeadMatmulFusePass, basic) { ...@@ -133,6 +134,12 @@ TEST(MultiHeadMatmulFusePass, basic) {
num_fused_nodes_after)); num_fused_nodes_after));
} }
TEST(MultiHeadMatmulFusePass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("multihead_matmul_fuse_pass_v2"));
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#define MAX_NUM_FC 10 #define MAX_NUM_FC 10
...@@ -174,6 +175,11 @@ void BuildRepeatedFCReluPattern(PDPattern* pattern, ...@@ -174,6 +175,11 @@ void BuildRepeatedFCReluPattern(PDPattern* pattern,
if (x->outputs.size() <= 0 || x->inputs.size() <= 0U) { if (x->outputs.size() <= 0 || x->inputs.size() <= 0U) {
return false; return false;
} }
if (x->IsVar() && x->Var() && x->Var()->GetShape().size() > 2) {
VLOG(3) << "repeated fc relu only supports input dims = 2, so it "
"is not applied.";
return false;
}
int fc_idx = FindFCIdx(x); int fc_idx = FindFCIdx(x);
if (fc_idx < 0) { if (fc_idx < 0) {
return false; return false;
...@@ -384,3 +390,8 @@ void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -384,3 +390,8 @@ void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(repeated_fc_relu_fuse_pass, REGISTER_PASS(repeated_fc_relu_fuse_pass,
paddle::framework::ir::RepeatedFCReluFusePass); paddle::framework::ir::RepeatedFCReluFusePass);
REGISTER_PASS_CAPABILITY(repeated_fc_relu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("fc", 0)
.EQ("relu", 0));
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -98,3 +99,9 @@ void SeqConvEltAddReluFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -98,3 +99,9 @@ void SeqConvEltAddReluFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(seqconv_eltadd_relu_fuse_pass, REGISTER_PASS(seqconv_eltadd_relu_fuse_pass,
paddle::framework::ir::SeqConvEltAddReluFusePass); paddle::framework::ir::SeqConvEltAddReluFusePass);
REGISTER_PASS_CAPABILITY(seqconv_eltadd_relu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("sequence_conv", 0)
.EQ("elementwise_add", 0)
.EQ("relu", 0));
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h" #include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -82,6 +83,9 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { ...@@ -82,6 +83,9 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
// Delete the unneeded nodes. // Delete the unneeded nodes.
GraphSafeRemoveNodes(graph, {reshape1_op, reshape1_out, transpose_op, GraphSafeRemoveNodes(graph, {reshape1_op, reshape1_out, transpose_op,
transpose_out, reshape2_op}); transpose_out, reshape2_op});
LOG_FIRST_N(WARNING, 1)
<< "There is fluid.layers.shuffle_channel API already, maybe you can "
"use it instead of (reshape + transpose + reshape)";
}; };
gpd(graph, handler); gpd(graph, handler);
...@@ -93,3 +97,8 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { ...@@ -93,3 +97,8 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(shuffle_channel_detect_pass, REGISTER_PASS(shuffle_channel_detect_pass,
paddle::framework::ir::ShuffleChannelDetectPass); paddle::framework::ir::ShuffleChannelDetectPass);
REGISTER_PASS_CAPABILITY(shuffle_channel_detect_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("reshape2", 0)
.EQ("transpose2", 0));
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -180,3 +181,8 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -180,3 +181,8 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
REGISTER_PASS(skip_layernorm_fuse_pass, REGISTER_PASS(skip_layernorm_fuse_pass,
paddle::framework::ir::SkipLayerNormFusePass); paddle::framework::ir::SkipLayerNormFusePass);
REGISTER_PASS_CAPABILITY(skip_layernorm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("elementwise_add", 0)
.EQ("layer_norm", 0));
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h" #include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -54,6 +55,12 @@ TEST(SkipLayerNormFusePass, basic) { ...@@ -54,6 +55,12 @@ TEST(SkipLayerNormFusePass, basic) {
"The number of fusion nodes does not meet expectations after fuse")); "The number of fusion nodes does not meet expectations after fuse"));
} }
TEST(SkipLayerNormFusePass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("skip_layernorm_fuse_pass"));
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -77,7 +78,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, ...@@ -77,7 +78,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
}; };
auto is_fusion_input_var = [=](Node* x, const std::string& arg_name) { auto is_fusion_input_var = [=](Node* x, const std::string& arg_name) {
bool basic = var_is_op_input(x, "matmul", arg_name) && bool basic = (var_is_op_input(x, "matmul_v2", arg_name) ||
var_is_op_input(x, "matmul", arg_name)) &&
var_is_op_input(x, "square", "X"); var_is_op_input(x, "square", "X");
if (!basic) { if (!basic) {
return false; return false;
...@@ -88,7 +90,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, ...@@ -88,7 +90,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
} }
auto* squared_x = squared_x_op->outputs[0]; auto* squared_x = squared_x_op->outputs[0];
bool next_is_matmul_from_arg = bool next_is_matmul_from_arg =
var_is_op_input(squared_x, "matmul", arg_name) && (var_is_op_input(squared_x, "matmul_v2", arg_name) ||
var_is_op_input(squared_x, "matmul", arg_name)) &&
squared_x->outputs.size() == 1 && squared_x->outputs.size() == 1 &&
squared_x->outputs[0]->outputs.size() == 1; squared_x->outputs[0]->outputs.size() == 1;
if (!next_is_matmul_from_arg) { if (!next_is_matmul_from_arg) {
...@@ -103,7 +106,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, ...@@ -103,7 +106,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
auto is_fusion_first_mul_out = [=](Node* x) -> bool { auto is_fusion_first_mul_out = [=](Node* x) -> bool {
bool input_is_matmul_op = x && x->inputs.size() == 1 && bool input_is_matmul_op = x && x->inputs.size() == 1 &&
x->inputs[0]->IsOp() && x->inputs[0]->IsOp() &&
x->inputs[0]->Op()->Type() == "matmul"; (x->inputs[0]->Op()->Type() == "matmul_v2" ||
x->inputs[0]->Op()->Type() == "matmul");
if (!input_is_matmul_op) { if (!input_is_matmul_op) {
return false; return false;
} }
...@@ -167,7 +171,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, ...@@ -167,7 +171,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
auto* matmul_xy_op = pattern->NewNode( auto* matmul_xy_op = pattern->NewNode(
[=](Node* x) { [=](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "matmul" && return x && x->IsOp() && (x->Op()->Type() == "matmul_v2" ||
x->Op()->Type() == "matmul") &&
is_fusion_first_mul_out(x->outputs[0]); is_fusion_first_mul_out(x->outputs[0]);
}, },
name_scope + "/matmul_xy_op"); name_scope + "/matmul_xy_op");
...@@ -189,7 +194,9 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, ...@@ -189,7 +194,9 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
auto is_fusion_mat_squared_x_y_op_out = [=](Node* x) -> bool { auto is_fusion_mat_squared_x_y_op_out = [=](Node* x) -> bool {
bool basic = x && x->IsVar() && x->inputs.size() == 1 && bool basic = x && x->IsVar() && x->inputs.size() == 1 &&
x->inputs[0]->IsOp() && x->inputs[0]->Op()->Type() == "matmul"; x->inputs[0]->IsOp() &&
(x->inputs[0]->Op()->Type() == "matmul_v2" ||
x->inputs[0]->Op()->Type() == "matmul");
if (!basic) { if (!basic) {
return false; return false;
} }
...@@ -206,7 +213,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, ...@@ -206,7 +213,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
auto* matmul_squared_x_y_op = pattern->NewNode( auto* matmul_squared_x_y_op = pattern->NewNode(
[=](Node* x) { [=](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "matmul" && return x && x->IsOp() && (x->Op()->Type() == "matmul_v2" ||
x->Op()->Type() == "matmul") &&
is_fusion_mat_squared_x_y_op_out(x->outputs[0]); is_fusion_mat_squared_x_y_op_out(x->outputs[0]);
}, },
name_scope + "/matmul_squared_x_y_op"); name_scope + "/matmul_squared_x_y_op");
...@@ -378,3 +386,13 @@ void SquaredMatSubFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -378,3 +386,13 @@ void SquaredMatSubFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(squared_mat_sub_fuse_pass, REGISTER_PASS(squared_mat_sub_fuse_pass,
paddle::framework::ir::SquaredMatSubFusePass); paddle::framework::ir::SquaredMatSubFusePass);
REGISTER_PASS_CAPABILITY(squared_mat_sub_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul", 0)
.EQ("matmul_v2", 0)
.EQ("square", 0)
.EQ("elementwise_mul", 0)
.EQ("elementwise_sub", 0)
.EQ("fill_constant", 0)
.EQ("fusion_squared_mat_sub", 0));
...@@ -24,7 +24,7 @@ namespace framework { ...@@ -24,7 +24,7 @@ namespace framework {
namespace ir { namespace ir {
/** /**
* Fuse ( (A.^2 * B.^2) - (A * B).^2 ) .* scalar * Fuse ( (A * B).^2 - (A.^2 * B.^2) ) .* scalar
*/ */
class SquaredMatSubFusePass : public FusePassBase { class SquaredMatSubFusePass : public FusePassBase {
public: public:
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h" #include "paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -145,3 +146,11 @@ void TransposeFlattenConcatFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -145,3 +146,11 @@ void TransposeFlattenConcatFusePass::ApplyImpl(ir::Graph *graph) const {
REGISTER_PASS(transpose_flatten_concat_fuse_pass, REGISTER_PASS(transpose_flatten_concat_fuse_pass,
paddle::framework::ir::TransposeFlattenConcatFusePass); paddle::framework::ir::TransposeFlattenConcatFusePass);
REGISTER_PASS_CAPABILITY(transpose_flatten_concat_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("transpose", 0)
.EQ("transpose2", 0)
.EQ("flatten", 0)
.EQ("concat", 0)
.EQ("fusion_transpose_flatten_concat", 0));
...@@ -69,7 +69,8 @@ class OpInfo { ...@@ -69,7 +69,8 @@ class OpInfo {
const OpCreator& Creator() const { const OpCreator& Creator() const {
PADDLE_ENFORCE_NOT_NULL(creator_, PADDLE_ENFORCE_NOT_NULL(creator_,
"Operator's Creator has not been registered"); platform::errors::NotFound(
"Operator's Creator has not been registered."));
return creator_; return creator_;
} }
...@@ -79,11 +80,12 @@ class OpInfo { ...@@ -79,11 +80,12 @@ class OpInfo {
std::string type = proto_ ? proto_->type() : "unknown"; std::string type = proto_ ? proto_->type() : "unknown";
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
grad_op_maker_, grad_op_maker_,
platform::errors::NotFound(
"Operator %s's GradOpMaker has not been " "Operator %s's GradOpMaker has not been "
"registered.\nPlease check whether %s_op has " "registered.\nPlease check whether (%s) operator has "
"grad_op.\nIf not, please set stop_gradient to True " "gradient operator.\nIf not, please set stop_gradient to be True "
"for its input and output variables using var.stop_gradient=True.", "for its input and output variables using var.stop_gradient=True.",
type.c_str(), type.c_str()); type.c_str(), type.c_str()));
return grad_op_maker_; return grad_op_maker_;
} }
...@@ -100,11 +102,12 @@ class OpInfo { ...@@ -100,11 +102,12 @@ class OpInfo {
std::string type = proto_ ? proto_->type() : "unknown"; std::string type = proto_ ? proto_->type() : "unknown";
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
dygraph_grad_op_maker_, dygraph_grad_op_maker_,
platform::errors::NotFound(
"Operator %s's DygraphGradOpMaker has not been " "Operator %s's DygraphGradOpMaker has not been "
"registered.\nPlease check whether %s_op has " "registered.\nPlease check whether (%s) operator has "
"grad_op.\nIf not, please set stop_gradient to True " "gradient operator.\nIf not, please set stop_gradient to be True "
"for its input and output variables using var.stop_gradient=True.", "for its input and output variables using var.stop_gradient=True.",
type.c_str(), type.c_str()); type.c_str(), type.c_str()));
return dygraph_grad_op_maker_; return dygraph_grad_op_maker_;
} }
...@@ -130,14 +133,17 @@ class OpInfoMap { ...@@ -130,14 +133,17 @@ class OpInfoMap {
} }
void Insert(const std::string& type, const OpInfo& info) { void Insert(const std::string& type, const OpInfo& info) {
PADDLE_ENFORCE(!Has(type), "Operator %s has been registered", type); PADDLE_ENFORCE_NE(Has(type), true,
platform::errors::AlreadyExists(
"Operator (%s) has been registered.", type));
map_.insert({type, info}); map_.insert({type, info});
} }
const OpInfo& Get(const std::string& type) const { const OpInfo& Get(const std::string& type) const {
auto op_info_ptr = GetNullable(type); auto op_info_ptr = GetNullable(type);
PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not been registered", PADDLE_ENFORCE_NOT_NULL(
type); op_info_ptr,
platform::errors::NotFound("Operator (%s) is not registered.", type));
return *op_info_ptr; return *op_info_ptr;
} }
......
...@@ -33,10 +33,18 @@ size_t OpKernelType::Hash::operator()(const OpKernelType& key) const { ...@@ -33,10 +33,18 @@ size_t OpKernelType::Hash::operator()(const OpKernelType& key) const {
cur_loc += OpKernelType::kLibBits; cur_loc += OpKernelType::kLibBits;
int customized_value = key.customized_type_value_; int customized_value = key.customized_type_value_;
PADDLE_ENFORCE(customized_value < (1 << OpKernelType::kCustomizeBits)); PADDLE_ENFORCE_LT(customized_value, (1 << OpKernelType::kCustomizeBits),
platform::errors::Unavailable(
"Too many custom OpKernel attribute values, expected "
"maximum value is %d, received value is %d.",
(1 << OpKernelType::kCustomizeBits), customized_value));
customized_value = customized_value << cur_loc; customized_value = customized_value << cur_loc;
cur_loc += OpKernelType::kCustomizeBits; cur_loc += OpKernelType::kCustomizeBits;
PADDLE_ENFORCE(cur_loc < 64); PADDLE_ENFORCE_LT(cur_loc, 64,
platform::errors::Unavailable(
"Too many OpKernel attribute values, expected maximum "
"value is 64, received value is %d.",
cur_loc));
std::hash<int> hasher; std::hash<int> hasher;
return hasher(place + data_type + data_layout + library_type + return hasher(place + data_type + data_layout + library_type +
......
...@@ -43,7 +43,9 @@ OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput( ...@@ -43,7 +43,9 @@ OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() { void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
std::unordered_set<std::string> names; std::unordered_set<std::string> names;
auto checker = [&](const std::string& name) { auto checker = [&](const std::string& name) {
PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name); PADDLE_ENFORCE_EQ(
names.count(name), 0,
platform::errors::AlreadyExists("Attribute [%s] is duplicated.", name));
names.insert(name); names.insert(name);
}; };
for (auto& attr : proto_->attrs()) { for (auto& attr : proto_->attrs()) {
......
...@@ -54,9 +54,10 @@ class Registrar { ...@@ -54,9 +54,10 @@ class Registrar {
template <typename... ARGS> template <typename... ARGS>
struct OperatorRegistrar : public Registrar { struct OperatorRegistrar : public Registrar {
explicit OperatorRegistrar(const char* op_type) { explicit OperatorRegistrar(const char* op_type) {
if (OpInfoMap::Instance().Has(op_type)) { PADDLE_ENFORCE_EQ(
PADDLE_THROW("'%s' is registered more than once.", op_type); OpInfoMap::Instance().Has(op_type), false,
} platform::errors::AlreadyExists(
"Operator '%s' is registered more than once.", op_type));
static_assert(sizeof...(ARGS) != 0, static_assert(sizeof...(ARGS) != 0,
"OperatorRegistrar should be invoked at least by OpClass"); "OperatorRegistrar should be invoked at least by OpClass");
OpInfo info; OpInfo info;
......
...@@ -58,7 +58,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -58,7 +58,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
AddInput("input", "input of cosine op").AsDuplicable(); AddInput("input", "input of cosine op").AsDuplicable();
AddOutput("output", "output of cosine op").AsIntermediate(); AddOutput("output", "output of cosine op").AsIntermediate();
auto my_checker = [](int i) { auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); PADDLE_ENFORCE_EQ(i % 2, 0, platform::errors::InvalidArgument(
"'test_attr' must be even!"));
}; };
AddAttr<int>("test_attr", "a simple test attribute") AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker); .AddCustomChecker(my_checker);
......
...@@ -133,6 +133,9 @@ class OpVersion { ...@@ -133,6 +133,9 @@ class OpVersion {
checkpoints_.push_back(Checkpoint({note, op_version_desc})); checkpoints_.push_back(Checkpoint({note, op_version_desc}));
return *this; return *this;
} }
uint32_t GetVersionID() const {
return static_cast<uint32_t>(checkpoints_.size());
}
private: private:
struct Checkpoint { struct Checkpoint {
...@@ -149,13 +152,21 @@ class OpVersionRegistrar { ...@@ -149,13 +152,21 @@ class OpVersionRegistrar {
return instance; return instance;
} }
OpVersion& Register(const std::string& op_type) { OpVersion& Register(const std::string& op_type) {
if (op_version_map_.find(op_type) != op_version_map_.end()) { PADDLE_ENFORCE_EQ(
PADDLE_THROW("'%s' is registered in operator version more than once.", op_version_map_.find(op_type), op_version_map_.end(),
op_type); platform::errors::AlreadyExists(
} "'%s' is registered in operator version more than once.", op_type));
op_version_map_.insert({op_type, OpVersion()}); op_version_map_.insert({op_type, OpVersion()});
return op_version_map_[op_type]; return op_version_map_[op_type];
} }
uint32_t GetVersionID(const std::string& op_type) const {
auto it = op_version_map_.find(op_type);
if (it == op_version_map_.end()) {
return 0;
}
return it->second.GetVersionID();
}
private: private:
std::unordered_map<std::string, OpVersion> op_version_map_; std::unordered_map<std::string, OpVersion> op_version_map_;
...@@ -164,6 +175,125 @@ class OpVersionRegistrar { ...@@ -164,6 +175,125 @@ class OpVersionRegistrar {
OpVersionRegistrar& operator=(const OpVersionRegistrar&) = delete; OpVersionRegistrar& operator=(const OpVersionRegistrar&) = delete;
}; };
class OpVersionComparator {
public:
virtual bool operator()() = 0;
virtual ~OpVersionComparator() = default;
};
#define ADD_OP_VERSION_COMPARATOR(cmp_name, cmp_math) \
class OpVersion##cmp_name##Comparator : public OpVersionComparator { \
public: \
explicit OpVersion##cmp_name##Comparator(const std::string op_name, \
uint32_t target_version) \
: op_name_(op_name), target_version_(target_version) {} \
virtual bool operator()() { \
return OpVersionRegistrar::GetInstance().GetVersionID(op_name_) \
cmp_math target_version_; \
} \
virtual ~OpVersion##cmp_name##Comparator() {} \
\
private: \
std::string op_name_; \
uint32_t target_version_; \
};
ADD_OP_VERSION_COMPARATOR(LE, <=);
ADD_OP_VERSION_COMPARATOR(EQ, ==);
ADD_OP_VERSION_COMPARATOR(GE, >=);
ADD_OP_VERSION_COMPARATOR(NE, !=);
class OpVersionComparatorCombination {
public:
OpVersionComparatorCombination() {}
OpVersionComparatorCombination& LE(const std::string& op_name,
int target_version) {
op_version_comparators_.push_back(std::shared_ptr<OpVersionComparator>(
new OpVersionLEComparator(op_name, target_version)));
return *this;
}
OpVersionComparatorCombination& EQ(const std::string& op_name,
int target_version) {
op_version_comparators_.push_back(std::shared_ptr<OpVersionComparator>(
new OpVersionEQComparator(op_name, target_version)));
return *this;
}
OpVersionComparatorCombination& GE(const std::string& op_name,
int target_version) {
op_version_comparators_.push_back(std::shared_ptr<OpVersionComparator>(
new OpVersionGEComparator(op_name, target_version)));
return *this;
}
OpVersionComparatorCombination& NE(const std::string& op_name,
int target_version) {
op_version_comparators_.push_back(std::shared_ptr<OpVersionComparator>(
new OpVersionNEComparator(op_name, target_version)));
return *this;
}
bool IsMatched() const {
for (const auto& cmp : op_version_comparators_) {
if (!(*cmp)()) {
return false;
}
}
return true;
}
private:
std::vector<std::shared_ptr<OpVersionComparator>> op_version_comparators_;
};
class PassVersionCheckers {
public:
PassVersionCheckers& AddCombination(
const OpVersionComparatorCombination& combinations) {
pass_version_checkers_.push_back(combinations);
return *this;
}
bool IsPassCompatible() const {
if (pass_version_checkers_.empty()) {
return true;
}
for (const auto& checker : pass_version_checkers_) {
if (checker.IsMatched()) {
return true;
}
}
return false;
}
private:
std::vector<OpVersionComparatorCombination> pass_version_checkers_;
};
class PassVersionCheckerRegistrar {
public:
static PassVersionCheckerRegistrar& GetInstance() {
static PassVersionCheckerRegistrar instance;
return instance;
}
PassVersionCheckers& Register(const std::string& pass_name) {
return pass_version_checkers_map_[pass_name];
}
bool IsPassCompatible(const std::string& fuse_pass_name) const {
auto iter = pass_version_checkers_map_.find(fuse_pass_name);
if (iter == pass_version_checkers_map_.end()) {
return true;
}
return iter->second.IsPassCompatible();
}
private:
std::unordered_map<std::string, PassVersionCheckers>
pass_version_checkers_map_;
PassVersionCheckerRegistrar() = default;
PassVersionCheckerRegistrar& operator=(const PassVersionCheckerRegistrar&) =
delete;
};
} // namespace compatible } // namespace compatible
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -173,3 +303,9 @@ class OpVersionRegistrar { ...@@ -173,3 +303,9 @@ class OpVersionRegistrar {
RegisterOpVersion__##op_type = \ RegisterOpVersion__##op_type = \
paddle::framework::compatible::OpVersionRegistrar::GetInstance() \ paddle::framework::compatible::OpVersionRegistrar::GetInstance() \
.Register(#op_type) .Register(#op_type)
#define REGISTER_PASS_CAPABILITY(pass_name) \
static auto RegisterOpPassVersionChecker__##pass_name = \
paddle::framework::compatible::PassVersionCheckerRegistrar:: \
GetInstance() \
.Register(#pass_name)
...@@ -55,6 +55,72 @@ TEST(test_operator_version, test_operator_version) { ...@@ -55,6 +55,72 @@ TEST(test_operator_version, test_operator_version) {
.NewInput("X2", "The second input.") .NewInput("X2", "The second input.")
.NewOutput("Y2", "The second output.")); .NewOutput("Y2", "The second output."));
} }
TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"no_bind_pass"));
REGISTER_PASS_CAPABILITY(test_pass1)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("mul", 1)
.EQ("fc", 0));
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass1"));
REGISTER_PASS_CAPABILITY(test_pass2)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("mul", 0)
.NE("fc", 0));
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass2"));
REGISTER_PASS_CAPABILITY(test_pass3)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("mul", 0)
.NE("fc", 0))
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("mul", 1)
.EQ("fc", 0));
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass3"));
REGISTER_PASS_CAPABILITY(test_pass4)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("test__", 5)
.EQ("fc", 0));
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass4"));
REGISTER_PASS_CAPABILITY(test_pass5)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("test__", 4)
.EQ("fc", 0));
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass5"));
REGISTER_PASS_CAPABILITY(test_pass6)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("test__", 4)
.EQ("fc", 0));
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass6"));
REGISTER_PASS_CAPABILITY(test_pass7)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.NE("test__", 4)
.EQ("fc", 0));
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass7"));
}
} // namespace compatible } // namespace compatible
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册