Note: “doc/design/functions_operators_layers.html” does not exist at “90c06ac665b79647dec205da0a50e8cffd8b50b7”.
Commit b62406bc, authored by ReeseWang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into trt_stack_op, test=develop

...@@ -79,7 +79,6 @@ find_package(Threads REQUIRED)
include(simd)

################################ Exposed Configurations #######################################
-option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
option(WITH_PYTHON "Compile PaddlePaddle with python interpreter" ON)
option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF)
...@@ -107,6 +106,7 @@ option(SANITIZER_TYPE "Choose the type of sanitizer, options are: Address, Leak,
option(WITH_LITE "Compile Paddle Fluid with Lite Engine" OFF)
option(WITH_NCCL "Compile PaddlePaddle with NCCL support" ON)
option(WITH_CRYPTO "Compile PaddlePaddle with crypto support" ON)
+option(WITH_ARM "Compile PaddlePaddle with arm support" OFF)

# PY_VERSION
if(NOT PY_VERSION)
...@@ -168,6 +168,9 @@ if(WITH_BRPC_RDMA)
  endif()
endif()

+# lite subgraph compilation depends on CUDNN_ROOT,
+# so include(cudnn) needs to be in front of include(third_party/lite)
+include(cudnn)       # set cudnn libraries, must before configure
include(third_party) # download, build, install third_party
if(WITH_DISTRIBUTE)
...@@ -187,7 +190,6 @@ if(NOT WIN32)
endif()

include(flags) # set paddle compile flags
-include(cudnn) # set cudnn libraries, must before configure

if(WITH_GPU)
  include(cuda)
...@@ -213,6 +215,12 @@ if(WITH_AMD_GPU)
  include(hip)
endif(WITH_AMD_GPU)

+if(WITH_ARM)
+  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
+  add_definitions(-DPADDLE_WITH_ARM)
+endif()
+
set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
......
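Since the new WITH_ARM switch injects -DPADDLE_WITH_ARM globally, C++ sources can gate ARM-only code paths on that macro. A minimal sketch of consuming it (the function below is hypothetical, not part of this commit):

#include <cstdio>

// Hypothetical illustration of the PADDLE_WITH_ARM macro that the new
// WITH_ARM option defines via add_definitions above.
void ReportBuildTarget() {
#ifdef PADDLE_WITH_ARM
  // WITH_ARM builds force -fPIC and skip the -m64 flag (see flags.cmake below).
  std::printf("PaddlePaddle built with ARM support\n");
#else
  std::printf("PaddlePaddle built without ARM support\n");
#endif
}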
...@@ -16,10 +16,6 @@ if(NOT WITH_PYTHON)
add_definitions(-DPADDLE_NO_PYTHON)
endif(NOT WITH_PYTHON)

-if(WITH_DSO)
-  add_definitions(-DPADDLE_USE_DSO)
-endif(WITH_DSO)

if(WITH_TESTING)
add_definitions(-DPADDLE_WITH_TESTING)
endif(WITH_TESTING)
...@@ -70,10 +66,6 @@ endif()

if(WITH_GPU)
  add_definitions(-DPADDLE_WITH_CUDA)
  add_definitions(-DEIGEN_USE_GPU)
-  # The compiler fully support const expressions since c++14,
-  # but Eigen use some const expressions such as std::max and std::min, which are not supported in c++11
-  # use following definition to set EIGEN_HAS_CONSTEXPR=0 to avoid compilation error in c++11
-  add_definitions(-DEIGEN_MAX_CPP_VER=11)
  FIND_PACKAGE(CUDA REQUIRED)
......
...@@ -188,12 +188,6 @@ endif()

add_definitions("-DPADDLE_CUDA_BINVER=\"${CUDA_VERSION_MAJOR}${CUDA_VERSION_MINOR}\"")

-if(NOT WITH_DSO)
-  if(WIN32)
-    set_property(GLOBAL PROPERTY CUDA_MODULES ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY} ${CUDA_cusolver_LIBRARY})
-  endif(WIN32)
-endif(NOT WITH_DSO)

# setting nvcc arch flags
select_nvcc_arch_flags(NVCC_FLAGS_EXTRA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS_EXTRA}")
......
...@@ -49,9 +49,14 @@ elseif(LINUX)
  # refer to: https://gitlab.com/libeigen/eigen/-/blob/4da2c6b1974827b1999bab652a3d4703e1992d26/Eigen/src/Core/arch/SSE/PacketMath.h#L33-60
  # add -fabi-version=4 could avoid above error, but will cause "double free corruption" when compile with gcc8
  # so use following patch to solve compilation error with different version of gcc.
-  file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/Geometry_SSE.h native_src)
-  file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/Eigen/src/Geometry/arch/Geometry_SSE.h native_dst)
-  set(EIGEN_PATCH_COMMAND cp ${native_src} ${native_dst})
+  file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/Geometry_SSE.h native_src1)
+  file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/Eigen/src/Geometry/arch/Geometry_SSE.h native_dst1)
+  # The compiler fully supports const expressions since C++14,
+  # but Eigen uses some const expressions such as std::max and std::min, which are not supported in C++11;
+  # add a patch to avoid the compilation error in C++11.
+  file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/MathFunctions.h native_src2)
+  file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/Eigen/src/Core/MathFunctions.h native_dst2)
+  set(EIGEN_PATCH_COMMAND cp ${native_src1} ${native_dst1} && cp ${native_src2} ${native_dst2})
endif()

set(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR})
......
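The C++11 limitation that the MathFunctions.h patch works around can be reproduced in isolation: std::max and std::min only became constexpr in C++14, so Eigen's constexpr uses of them fail under -std=c++11. A minimal sketch:

#include <algorithm>

constexpr int kA = 3;
constexpr int kB = 5;
// Compiles under -std=c++14 and later; a compile error under -std=c++11,
// because std::max was not declared constexpr until C++14.
constexpr int kMax = std::max(kA, kB);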
...@@ -25,7 +25,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
  set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite)
  if(NOT LITE_GIT_TAG)
-    set(LITE_GIT_TAG 34c29406c27ee00cef033a98887403443eb2565f)
+    set(LITE_GIT_TAG ab8af5c4b4dc5b40217633e0aa436315912d7b53)
  endif()
  if(NOT CUDA_ARCH_NAME)
...@@ -93,6 +93,7 @@ function(external_lite_static_libs alias path)
endfunction()

external_lite_static_libs(lite_full_static ${LITE_BINARY_DIR}/inference_lite_lib/cxx/lib/libpaddle_full_api_shared.so)
+set(LITE_SHARED_LIB ${LITE_BINARY_DIR}/inference_lite_lib/cxx/lib/libpaddle_full_api_shared.so)
add_definitions(-DPADDLE_WITH_LITE)
add_definitions(-DLITE_WITH_LOG)
...@@ -36,28 +36,12 @@ SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/${LIBDIR
INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) # For MKLDNN code to include internal headers.

-IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
-  SET(MKLDNN_DEPENDS ${MKLML_PROJECT})
-  MESSAGE(STATUS "Build MKLDNN with MKLML ${MKLML_ROOT}")
-ELSE()
-  MESSAGE(STATUS "Build MKLDNN without MKLML")
-ENDIF()

IF(NOT WIN32)
  SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds")
  SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value")
  SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}")
  SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}")
-  IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
-    # Force libmkldnn.so to link libiomp5.so (provided by intel mkl) instead of libgomp.so (provided by gcc),
-    # since core_avx.so links libiomp5.so
-    set(MKLDNN_SHARED_LINKER_FLAG "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--no-as-needed -L${MKLML_LIB_DIR} -liomp5")
-    set(FORBID "-fopenmp")
-  ELSE()
-    set(MKLDNN_SHARED_LINKER_FLAG "${CMAKE_SHARED_LINKER_FLAGS}")
-    set(FORBID "")
-  ENDIF()
ELSE()
  SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} /EHsc")
ENDIF(NOT WIN32)
...@@ -91,8 +75,6 @@ ExternalProject_Add(
    -DCMAKE_C_FLAGS=${MKLDNN_CFLAG}
    -DCMAKE_CXX_FLAGS=${MKLDNN_CXXFLAG}
    -DDNNL_BUILD_TESTS=OFF -DDNNL_BUILD_EXAMPLES=OFF
-    -DCMAKE_SHARED_LINKER_FLAGS=${MKLDNN_SHARED_LINKER_FLAG}
-    -DCMAKE_CXX_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS=${FORBID}
    CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR}
)
if(WIN32)
......
...@@ -19,6 +19,9 @@ SET(CBLAS_SOURCE_DIR ${THIRD_PARTY_PATH}/openblas/src/extern_openblas)
SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas)
SET(CBLAS_REPOSITORY https://github.com/xianyi/OpenBLAS.git)
SET(CBLAS_TAG v0.3.7)
+IF(WITH_ARM)
+  SET(CBLAS_TAG v0.2.18)
+ENDIF()
cache_third_party(extern_openblas
    REPOSITORY ${CBLAS_REPOSITORY}
    TAG ${CBLAS_TAG}
......
...@@ -187,7 +187,7 @@ set(GPU_COMMON_FLAGS
    -Wno-error=unused-function # Warnings in Numpy Header.
    -Wno-error=array-bounds # Warnings in Eigen::array
)
-if (NOT WITH_NV_JETSON)
+if (NOT WITH_NV_JETSON AND NOT WITH_ARM)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64")
endif()
endif(NOT WIN32)
......
...@@ -89,6 +89,8 @@

# including binary directory for generated headers.
include_directories(${CMAKE_CURRENT_BINARY_DIR})
+# including io directory for inference lib paddle_api.h
+include_directories("${PADDLE_SOURCE_DIR}/paddle/fluid/framework/io")

if(NOT APPLE)
  find_package(Threads REQUIRED)
......
...@@ -13,10 +13,6 @@ include_directories("/opt/rocm/thrust")

set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fPIC -DPADDLE_WITH_HIP -std=c++11" )

-if(WITH_DSO)
-  set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_USE_DSO")
-endif(WITH_DSO)

if(WITH_TESTING)
  set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_WITH_TESTING")
endif(WITH_TESTING)
......
...@@ -107,6 +107,11 @@ function(copy_part_of_thrid_party TARGET DST)
        SRCS ${GLOG_INCLUDE_DIR} ${GLOG_LIBRARIES}
        DSTS ${dst_dir} ${dst_dir}/lib)

+    set(dst_dir "${DST}/third_party/install/cryptopp")
+    copy(${TARGET}
+        SRCS ${CRYPTOPP_INCLUDE_DIR} ${CRYPTOPP_LIBRARIES}
+        DSTS ${dst_dir} ${dst_dir}/lib)
+
    set(dst_dir "${DST}/third_party/install/xxhash")
    copy(${TARGET}
        SRCS ${XXHASH_INCLUDE_DIR} ${XXHASH_LIBRARIES}
...@@ -178,7 +183,10 @@ endif()

copy(inference_lib_dist
    SRCS ${CMAKE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h
    DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include/internal)
+copy(inference_lib_dist
+    SRCS ${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io/crypto/cipher.h
+    DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include/crypto/)
+include_directories(${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io)

# CAPI inference library for only inference
set(FLUID_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_c_install_dir" CACHE STRING
    "A path setting CAPI fluid inference shared")
......
...@@ -37,16 +37,10 @@ find_library(TENSORRT_LIBRARY NAMES ${TR_INFER_LIB} ${TR_INFER_RT}
    DOC "Path to TensorRT library.")

if(TENSORRT_INCLUDE_DIR AND TENSORRT_LIBRARY)
-  if(WITH_DSO)
-    set(TENSORRT_FOUND ON)
-  endif(WITH_DSO)
+  set(TENSORRT_FOUND ON)
else()
  set(TENSORRT_FOUND OFF)
-  if(WITH_DSO)
-    message(WARNING "TensorRT is NOT found when WITH_DSO is ON.")
-  else(WITH_DSO)
-    message(STATUS "TensorRT is disabled because WITH_DSO is OFF.")
-  endif(WITH_DSO)
+  message(STATUS "TensorRT is disabled.")
endif()

if(TENSORRT_FOUND)
......
paddle.fluid.optimizer.PipelineOptimizer (paddle.fluid.optimizer.PipelineOptimizer, ('document', '2e55a29dbeb874934f7a1a1af3a22b8c'))
paddle.fluid.optimizer.PipelineOptimizer.__init__ (ArgSpec(args=['self', 'optimizer', 'num_microbatches', 'start_cpu_core_id'], varargs=None, keywords=None, defaults=(1, 0)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.PipelineOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
...@@ -155,22 +155,31 @@ nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)

if(WITH_PYTHON)
  py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
  py_proto_compile(trainer_py_proto SRCS trainer_desc.proto data_feed.proto)
+  py_proto_compile(distributed_strategy_py_proto SRCS distributed_strategy.proto)
  #Generate an empty \
  #__init__.py to make framework_py_proto as a valid python module.
  add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
-  add_dependencies(framework_py_proto framework_py_proto_init trainer_py_proto)
+  add_dependencies(framework_py_proto framework_py_proto_init trainer_py_proto distributed_strategy_py_proto)
  if (NOT WIN32)
    add_custom_command(TARGET framework_py_proto POST_BUILD
      COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto
+      COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fleet/proto
+      COMMAND ${CMAKE_COMMAND} -E touch ${PADDLE_BINARY_DIR}/python/paddle/fleet/proto/__init__.py
      COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/
+      COMMAND cp distributed_strategy_*.py ${PADDLE_BINARY_DIR}/python/paddle/fleet/proto
      COMMENT "Copy generated python proto into directory paddle/fluid/proto."
      WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
  else(NOT WIN32)
    string(REPLACE "/" "\\" proto_dstpath "${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/")
+    string(REPLACE "/" "\\" fleet_proto_dstpath "${PADDLE_BINARY_DIR}/python/paddle/fleet/proto/")
    add_custom_command(TARGET framework_py_proto POST_BUILD
      COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto
+      COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fleet/proto
+      COMMAND ${CMAKE_COMMAND} -E touch ${PADDLE_BINARY_DIR}/python/paddle/fleet/proto/__init__.py
      COMMAND copy /Y *.py ${proto_dstpath}
+      COMMAND copy /Y distributed_strategy_*.py ${fleet_proto_dstpath}
      COMMENT "Copy generated python proto into directory paddle/fluid/proto."
+      COMMENT "Copy generated python proto into directory paddle/fleet/proto."
      WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
  endif(NOT WIN32)
endif()
......
...@@ -101,6 +101,8 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo
        DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)
+cc_test(exception_holder_test SRCS exception_holder_test.cc)

set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
        multi_devices_graph_print_pass multi_devices_graph_check_pass
        fuse_elewise_add_act_pass fuse_bn_act_pass
......
...@@ -15,9 +15,11 @@
#pragma once

#include <memory>
+#include <mutex>  // NOLINT
#include <string>

#include "glog/logging.h"
+#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
...@@ -29,15 +31,16 @@ class ExceptionHolder {
  void Catch(std::exception_ptr eptr) {
    try {
      std::rethrow_exception(eptr);
+    } catch (memory::allocation::BadAlloc& exp) {
+      Catch(exp);
    } catch (platform::EOFException& exp) {
      Catch(exp);
    } catch (platform::EnforceNotMet& exp) {
      Catch(exp);
    } catch (std::exception& ex) {
-      PADDLE_THROW(platform::errors::Fatal(
-          "Unknown std::exception caught:\n%s.", ex.what()));
+      Catch(ex);
    } catch (...) {
-      PADDLE_THROW(platform::errors::Fatal("Unknown exception caught."));
+      LOG(FATAL) << "Unknown exception caught.";
    }
  }
...@@ -59,6 +62,15 @@ class ExceptionHolder {
        auto e = *static_cast<platform::EOFException*>(exception_.get());
        throw e;
      }
+      case kBadAlloc: {
+        auto e = *static_cast<paddle::memory::allocation::BadAlloc*>(
+            exception_.get());
+        throw e;
+      }
+      case kBaseException: {
+        auto e = *static_cast<std::exception*>(exception_.get());
+        throw e;
+      }
    }
    ClearImpl();
  }
...@@ -79,6 +91,12 @@ class ExceptionHolder {
      case kEOF: {
        return "EOF";
      }
+      case kBadAlloc: {
+        return "BadAlloc";
+      }
+      case kBaseException: {
+        return "BaseException";
+      }
    }
    return "unknown";
  }
...@@ -89,10 +107,31 @@ class ExceptionHolder {
    type_ = kNone;
  }

+  // NOTE: Currently in the ParallelExecutor, exceptions may be raised in
+  // multiple threads, and an exception raised later would overwrite one
+  // raised earlier, while what we want is the first triggered exception.
+  // The EOF exception has lower priority and may be overwritten, but the
+  // other exceptions must not be.
  void Catch(const platform::EnforceNotMet& exp) {
    std::lock_guard<std::mutex> lock(mu_);
-    exception_.reset(new platform::EnforceNotMet(exp));
-    type_ = kEnforceNotMet;
+    if (exception_.get() == nullptr || type_ == kEOF) {
+      exception_.reset(new platform::EnforceNotMet(exp));
+      type_ = kEnforceNotMet;
+    } else {
+      VLOG(2) << "Non-first exception is discarded, the error message is "
+              << exception_->what();
+    }
  }

+  void Catch(const memory::allocation::BadAlloc& exp) {
+    std::lock_guard<std::mutex> lock(mu_);
+    if (exception_.get() == nullptr || type_ == kEOF) {
+      exception_.reset(new paddle::memory::allocation::BadAlloc(exp));
+      type_ = kBadAlloc;
+    } else {
+      VLOG(2) << "Non-first exception is discarded, the error message is "
+              << exception_->what();
+    }
+  }

  void Catch(const platform::EOFException& exp) {
...@@ -101,10 +140,24 @@ class ExceptionHolder {
    if (exception_.get() == nullptr) {
      exception_.reset(new platform::EOFException(exp));
      type_ = kEOF;
+    } else {
+      VLOG(2) << "EOFException is skipped, its error message is "
+              << exception_->what();
+    }
+  }

+  void Catch(const std::exception& exp) {
+    std::lock_guard<std::mutex> lock(mu_);
+    if (exception_.get() == nullptr || type_ == kEOF) {
+      exception_.reset(new std::exception(exp));
+      type_ = kBaseException;
+    } else {
+      VLOG(2) << "Non-first exception is discarded, the error message is "
+              << exception_->what();
    }
  }

-  enum ExceptionType { kNone, kEnforceNotMet, kEOF };
+  enum ExceptionType { kNone, kEnforceNotMet, kEOF, kBadAlloc, kBaseException };
  ExceptionType type_{kNone};

  std::unique_ptr<std::exception> exception_;
......
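A minimal sketch of the intended behavior (not from this commit; it assumes the header path shown above): several worker threads report failures, and the holder keeps only the first non-EOF exception, as described in the NOTE.

#include <stdexcept>
#include <thread>

#include "paddle/fluid/framework/details/exception_holder.h"

void Demo() {
  paddle::framework::details::ExceptionHolder holder;
  auto fail = [&holder](const char* msg) {
    try {
      throw std::runtime_error(msg);
    } catch (...) {
      // Later exceptions are discarded (or overwrite only a pending EOF).
      holder.Catch(std::current_exception());
    }
  };
  std::thread t1(fail, "first error");
  std::thread t2(fail, "second error");
  t1.join();
  t2.join();
  if (holder.IsCaught()) {
    holder.ReThrow();  // rethrows whichever exception arrived first
  }
}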
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/exception_holder.h"
#include <memory>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/memory/allocation/allocator.h"
namespace paddle {
namespace framework {
namespace details {
namespace f = paddle::framework;
namespace p = paddle::platform;
TEST(ExceptionHolderTester, TestEnforceNotMetCatch) {
ExceptionHolder exception_holder;
try {
throw platform::EnforceNotMet("enforce not met test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "EnforceNotMet");
bool catch_enforce_not_met = false;
try {
exception_holder.ReThrow();
} catch (platform::EnforceNotMet& ex) {
catch_enforce_not_met = true;
} catch (...) {
catch_enforce_not_met = false;
}
ASSERT_TRUE(catch_enforce_not_met);
}
TEST(ExceptionHolderTester, TestBadAllocCatch) {
ExceptionHolder exception_holder;
try {
throw memory::allocation::BadAlloc("bad alloc test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BadAlloc");
bool catch_bad_alloc = false;
try {
exception_holder.ReThrow();
} catch (memory::allocation::BadAlloc& ex) {
catch_bad_alloc = true;
} catch (...) {
catch_bad_alloc = false;
}
ASSERT_TRUE(catch_bad_alloc);
}
TEST(ExceptionHolderTester, TestBaseExceptionCatch) {
ExceptionHolder exception_holder;
try {
throw std::exception();
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BaseException");
bool catch_base_exception = false;
try {
exception_holder.ReThrow();
} catch (std::exception& ex) {
catch_base_exception = true;
} catch (...) {
catch_base_exception = false;
}
ASSERT_TRUE(catch_base_exception);
}
TEST(ExceptionHolderTester, TestExceptionReplace) {
ExceptionHolder exception_holder;
try {
throw platform::EnforceNotMet("enforce not met test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "EnforceNotMet");
try {
throw std::exception();
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "EnforceNotMet");
try {
throw memory::allocation::BadAlloc("bad alloc test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "EnforceNotMet");
try {
throw platform::EOFException("eof test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_EQ(exception_holder.Type(), "EnforceNotMet");
exception_holder.Clear();
try {
throw memory::allocation::BadAlloc("bad alloc test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BadAlloc");
try {
throw platform::EnforceNotMet("enforce not met test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BadAlloc");
}
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -269,7 +269,14 @@ void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {

void FastThreadedSSAGraphExecutor::ExecutionFinal(
    std::vector<OpHandleBase *> *fetch_ops) {
  VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
-  ClearFetchOp(graph_, fetch_ops);
+  // NOTE: If a new exception were raised inside ClearFetchOp, the exception
+  // that was triggered first would be lost and never rethrown. Therefore,
+  // only perform the cleanup when an EOF exception is caught; if any other
+  // exception was triggered, ClearFetchOp must not proceed.
+  if (exception_.Type() == "EOF") {
+    ClearFetchOp(graph_, fetch_ops);
+  }
  exception_.ReThrow();
}
......
...@@ -36,7 +36,7 @@ OpHandleBase::~OpHandleBase() PADDLE_MAY_THROW {
#ifdef PADDLE_WITH_CUDA
  for (auto &ev : events_) {
    if (ev.second) {
-      PADDLE_ENFORCE(cudaEventDestroy(ev.second));
+      PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(ev.second));
    }
  }
#endif
......
...@@ -51,10 +51,6 @@ bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);

class FleetWrapper;

-#define SEC_LOG                                                              \
-  VLOG(3) << "[s" << section_id_ << "p" << pipeline_id_ << "t" << thread_id_ \
-          << "]: "

class PullDenseWorker {
 public:
  virtual ~PullDenseWorker() {}
...@@ -311,40 +307,9 @@ class DownpourWorkerOpt : public DownpourWorker {
};

#if defined(PADDLE_WITH_NCCL)
-using ScopeQueue = operators::reader::BlockingQueue<Scope*>;
-
-class SyncFunctor {
- public:
-  SyncFunctor(int rank_id, int rank_num, int sync_steps);
-  virtual ~SyncFunctor() {}
-
-  void SetSyncParam(const std::vector<std::string>& sync_param) {
-    sync_param_ = &sync_param;
-  }
-  void SetNcclCtxMap(platform::NCCLContextMap* nccl_ctx_map) {
-    nccl_ctx_map_ = nccl_ctx_map;
-  }
-  int operator()(Scope* scope);
-  static std::vector<Scope*> pipeline_scopes_;
-  static uint64_t sync_flag_;
-
- protected:
-  const int rank_id_;
-  const int rank_num_;
-  const std::vector<std::string>* sync_param_ = nullptr;
-  platform::NCCLContextMap* nccl_ctx_map_ = nullptr;
-  uint64_t sync_signal_;
-  const int sync_steps_;
-  int counter_;
-
-  void Synchronize();
-};

class SectionWorker : public DeviceWorker {
 public:
-  SectionWorker() {}
+  SectionWorker() { local_batch_id_ = 0; }
  ~SectionWorker() override {}

  void Initialize(const TrainerDesc& desc) override;
...@@ -360,50 +325,39 @@ class SectionWorker : public DeviceWorker {
  const platform::Place& place() const { return place_; }

  void SetSectionIndex(int section_id) { section_id_ = section_id; }
-  void SetDeviceIndex(int tid) override { pipeline_id_ = tid; }
+  void SetDeviceIndex(int tid) override {}
  void SetThreadIndex(int thread_id) { thread_id_ = thread_id; }
-  void SetVarNames(const std::vector<std::string>& in_var_names,
-                   const std::vector<std::string>& out_var_names) {
-    in_var_names_ = &in_var_names;
-    out_var_names_ = &out_var_names;
-  }
-  void SetScopeQueue(ScopeQueue* in_scope_queue, ScopeQueue* out_scope_queue) {
-    in_scope_queue_ = in_scope_queue;
-    out_scope_queue_ = out_scope_queue;
-  }
-  void SetCountMutex(std::mutex* mutex) { worker_count_mutex_ = mutex; }
-  void SetWorkerCount(int* worker_count) { worker_count_ = worker_count; }
-  void SetSectionNum(int section_num) { section_num_ = section_num; }
-  void SetPipelineNum(int pipeline_num) { pipeline_num_ = pipeline_num; }
-  void SetNextSectionPlace(const paddle::platform::Place& place) {
-    next_section_place_ = place;
-  }
-  SyncFunctor* sync_func_ = nullptr;
-  void SetSyncFunctor(SyncFunctor* sync_func) { sync_func_ = sync_func; }
+  void SetMicrobatchNum(int num) { num_microbatches_ = num; }
+  void SetMicrobatchScopes(const std::vector<Scope*>& scope) {
+    microbatch_scopes_ = scope;
+  }
+  void SetMinibatchScope(const Scope* scope) { minibatch_scope_ = scope; }
+  void SetSkipVars(const std::vector<std::string>& skip_vars) {
+    skip_vars_ = skip_vars;
+  }

  static std::atomic<int> cpu_id_;

 protected:
  void AutoSetCPUAffinity(bool reuse);
  int section_id_;
-  int pipeline_id_;
-  int section_num_;
-  int pipeline_num_;
  int thread_id_;
-  // This worker will consume scope from in_scope_queue_
-  // and produce scope to out_scope_queue_
-  ScopeQueue* in_scope_queue_ = nullptr;
-  ScopeQueue* out_scope_queue_ = nullptr;
-  const std::vector<std::string>* in_var_names_ = nullptr;
-  const std::vector<std::string>* out_var_names_ = nullptr;
-  std::mutex* worker_count_mutex_ = nullptr;
-  int* worker_count_ = nullptr;
-  paddle::platform::Place next_section_place_;
+  int num_microbatches_;
+  std::vector<Scope*> microbatch_scopes_;
+  std::vector<std::string> skip_vars_;
+  const Scope* minibatch_scope_;

  std::vector<std::unique_ptr<OperatorBase>> ops_;
+  static std::mutex thread_mutex;
+  static std::condition_variable thread_condition;
+  static bool threads_completed;
+  std::shared_ptr<framework::ProgramDesc> program_;
+  static uint64_t batch_id_;
+  uint64_t local_batch_id_;

  platform::DeviceContext* dev_ctx_ = nullptr;
};
#endif

}  // namespace framework
}  // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle.fleet;
enum Mode {
COLLECTIVE = 1;
PS = 2;
PIPELINE = 3;
HETER = 4; // support XPU and GPU computing server
}
message DistributedStrategy {
optional Mode mode = 1 [ default = COLLECTIVE ]; // just for serialization
// collective training strategy
optional bool amp = 2 [ default = false ];
optional int32 amp_loss_scaling = 3 [ default = 32768 ];
optional bool recompute = 4 [ default = false ];
repeated string recompute_checkpoints = 5;
optional bool localsgd = 6 [ default = false ];
optional int32 localsgd_k_step = 7 [ default = 4 ];
optional bool dgc = 8 [ default = false ];
optional bool hierachical_allreduce = 9 [ default = false ];
optional int32 nccl_comm_num = 10 [ default = 1 ];
optional bool gradient_merge = 11 [ default = false ];
optional int32 gradient_merge_k_step = 12 [ default = 1 ];
optional bool sequential_execution = 13 [ default = false ];
optional bool enable_backward_optimizer_op_deps = 14 [ default = true ];
optional bool lars = 15 [ default = false ];
optional bool lamb = 16 [ default = false ];
optional bool fuse_elewise_add_act_ops = 17 [ default = false ];
optional bool fuse_bn_act_ops = 18 [ default = false ];
optional bool enable_auto_fusion = 19 [ default = false ];
optional bool fuse_relu_depthwise_conv = 20 [ default = false ];
optional bool enable_inplace = 21 [ default = false ];
optional bool fuse_all_reduce_ops = 22 [ default = false ];
optional int32 num_iteration_per_drop_scope = 23 [ default = 1 ];
optional bool sync_batch_norm = 24 [ default = false ];
optional bool fuse_all_optimizer_ops = 25 [ default = false ];
// pipeline training
optional bool pipeline = 101 [ default = false ];
optional int32 pipeline_micro_batch = 102;
// parameter server training
optional bool sync = 201 [ default = false ];
optional bool async = 202 [ default = true ];
optional int32 async_k_step = 203 [ default = -1 ];
optional int32 max_merge_var_num = 204 [ default = 1 ];
optional int32 send_queue_size = 205 [ default = 16 ];
optional bool independent_recv_thread = 206 [ default = false ];
optional int32 min_send_grad_num_before_recv = 207 [ default = 1 ];
optional int32 thread_pool_size = 208 [ default = 1 ];
optional int32 send_wait_times = 209 [ default = 1 ];
optional bool runtime_split_send_recv = 210 [ default = false ];
optional bool use_thread_barrier = 211 [ default = false ];
// elastic deep learning strategies
optional bool elastic = 301 [ default = false ];
// auto parallel
optional bool auto = 401 [ default = false ];
}
message DistributedJobInfo {
optional int32 worker_num = 1;
optional int32 server_num = 2;
repeated string worker_ips = 3;
repeated string server_endpoints = 4;
optional string origin_startup = 5;
optional string origin_main = 6; // without backpropagation and optimization
optional string distributed_main = 7; // with backpropagation and optimization
optional string optimizer_name = 8; // optimizer name
optional DistributedStrategy strategy = 101;
}
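The generated proto2 API can be driven from C++ as well as Python. A minimal sketch (it assumes the protoc-generated header is named distributed_strategy.pb.h following the usual convention; the checkpoint name is illustrative):

#include <string>

#include "distributed_strategy.pb.h"  // assumed generated header name

int main() {
  paddle::fleet::DistributedStrategy strategy;
  strategy.set_mode(paddle::fleet::COLLECTIVE);
  strategy.set_amp(true);                // enable mixed-precision training
  strategy.set_amp_loss_scaling(32768);
  strategy.set_recompute(true);
  strategy.add_recompute_checkpoints("fc_0.tmp_0");  // hypothetical var name

  std::string wire;
  strategy.SerializeToString(&wire);  // proto2 binary serialization
  return 0;
}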
...@@ -23,9 +23,6 @@

namespace paddle {
namespace framework {

-#define CUDA_KERNEL_LOOP(i, n)                                 \
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
-       i += blockDim.x * gridDim.x)

template <size_t EMBEDX_DIM, size_t EXPAND_EMBED_DIM>
__global__ void PullCopy(
......
...@@ -18,9 +18,9 @@ limitations under the License. */

#include <glog/logging.h>
#include <gtest/gtest.h>
+#include <fstream>
#include <string>
#include <vector>

#include "paddle/fluid/framework/io/crypto/cipher.h"
#include "paddle/fluid/framework/io/crypto/cipher_utils.h"
......
...@@ -16,7 +16,9 @@

#include "paddle/fluid/framework/io/crypto/aes_cipher.h"
#include "paddle/fluid/framework/io/crypto/cipher_utils.h"
#include "paddle/fluid/platform/enforce.h"
+#ifdef ON_INFER
+#include "paddle/fluid/inference/api/paddle_api.h"
+#endif

namespace paddle {
namespace framework {
...@@ -57,4 +59,9 @@ std::shared_ptr<Cipher> CipherFactory::CreateCipher(
}
}  // namespace framework

+#ifdef ON_INFER
+std::shared_ptr<framework::Cipher> MakeCipher(const std::string& config_file) {
+  return framework::CipherFactory::CreateCipher(config_file);
+}
+#endif

}  // namespace paddle
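A minimal usage sketch for the new ON_INFER entry point (assuming the declaration of MakeCipher is visible via paddle_api.h, as the include added above suggests; the config path is hypothetical):

#include <memory>
#include <string>

#ifdef ON_INFER
#include "paddle/fluid/inference/api/paddle_api.h"  // assumed to declare MakeCipher
#endif

void LoadCipherForInference() {
#ifdef ON_INFER
  // MakeCipher is only compiled into inference builds (ON_INFER).
  std::shared_ptr<paddle::framework::Cipher> cipher =
      paddle::MakeCipher("./cipher_config.conf");  // hypothetical config path
#endif
}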
...@@ -46,6 +46,5 @@ class CipherFactory {
  CipherFactory() = default;
  static std::shared_ptr<Cipher> CreateCipher(const std::string& config_file);
};

}  // namespace framework
}  // namespace paddle
...@@ -14,11 +14,9 @@

#pragma once

-#include <sstream>
#include <string>
#include <unordered_map>

-#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
......
...@@ -21,7 +21,7 @@ namespace framework {

std::shared_ptr<FILE> shell_fopen(const std::string& path,
                                  const std::string& mode) {
-#if defined _WIN32 || defined __APPLE__ || defined PADDLE_ARM
+#if defined(_WIN32) || defined(__APPLE__) || defined(PADDLE_ARM)
  return nullptr;
#else
  if (shell_verbose()) {
...@@ -48,7 +48,7 @@ std::shared_ptr<FILE> shell_fopen(const std::string& path,
// The implementation is async signal safe
// Mostly copy from CPython code
static int close_open_fds_internal() {
-#if defined _WIN32 || defined __APPLE__ || defined PADDLE_ARM
+#if defined(_WIN32) || defined(__APPLE__) || defined(PADDLE_ARM)
  return 0;
#else
  struct linux_dirent {
...@@ -103,8 +103,9 @@ static int close_open_fds_internal() {
}

static int shell_popen_fork_internal(const char* real_cmd, bool do_read,
-                                     int parent_end, int child_end) {
-#if defined _WIN32 || defined __APPLE__
+                                     int parent_end, int child_end,
+                                     bool redirect_stderr = false) {
+#if defined(_WIN32) || defined(__APPLE__) || defined(PADDLE_ARM)
  return 0;
#else
  int child_pid = -1;
...@@ -125,18 +126,41 @@ static int shell_popen_fork_internal(const char* real_cmd, bool do_read,
    if (child_end != child_std_end) {
      PCHECK(dup2(child_end, child_std_end) == child_std_end);
+      if (redirect_stderr && do_read) {
+        PCHECK(dup2(child_end, 2) == 2);
+      }
      close(child_end);
    }

    close_open_fds_internal();
    PCHECK(execl("/bin/bash", "bash", "-c", real_cmd, NULL) >= 0);
-  exit(127);
+  // Note: just for compilation; the child never runs this line.
+  _exit(0);
#endif
}

+static int read_from_pipe(FILE* fp, std::string* output) {
+  char buf[4096];
+  while (1) {
+    int n = fread(buf, 1, 4096, fp);
+    if (n <= 0) {
+      break;
+    }
+    output->append(buf, n);
+  }
+
+  if (!feof(fp)) {
+    return -1;
+  }
+  return 0;
+}
std::shared_ptr<FILE> shell_popen(const std::string& cmd,
-                                  const std::string& mode, int* err_no) {
-#if defined _WIN32 || defined __APPLE__
+                                  const std::string& mode, int* err_no,
+                                  int* status, bool redirect_stderr) {
+#if defined(_WIN32) || defined(__APPLE__) || defined(PADDLE_ARM)
  return nullptr;
#else
  bool do_read = mode == "r";
...@@ -146,9 +170,7 @@ std::shared_ptr<FILE> shell_popen(const std::string& cmd,
    return NULL;
  }

-  if (shell_verbose()) {
-    LOG(INFO) << "Opening pipe[" << cmd << "] with mode[" << mode << "]";
-  }
+  VLOG(3) << "Opening pipe[" << cmd << "] with mode[" << mode << "]";

  std::string real_cmd = "set -o pipefail; " + cmd;
...@@ -168,43 +190,54 @@ std::shared_ptr<FILE> shell_popen(const std::string& cmd,
    child_end = pipe_fds[0];
  }

-  int child_pid = shell_popen_fork_internal(real_cmd.c_str(), do_read,
-                                            parent_end, child_end);
-  close(child_end);
+  sighandler_t old_handler;
+  old_handler = signal(SIGCHLD, SIG_DFL);

  fcntl(parent_end, F_SETFD, FD_CLOEXEC);
-  FILE* fp;
+
+  int child_pid = shell_popen_fork_internal(
+      real_cmd.c_str(), do_read, parent_end, child_end, redirect_stderr);
+  close(child_end);
+
+  FILE* fp = NULL;
  if ((fp = fdopen(parent_end, mode.c_str())) == NULL) {
    *err_no = -1;
+    signal(SIGCHLD, old_handler);
    return NULL;
  }

-  return {fp, [child_pid, cmd, err_no](FILE* fp) {
-            if (shell_verbose()) {
-              LOG(INFO) << "Closing pipe[" << cmd << "]";
-            }
-            if (fclose(fp) != 0) {
+  return {fp, [cmd, child_pid, old_handler, err_no, status](FILE* fp) {
+            VLOG(3) << "Closing pipe[" << cmd << "]";
+            if (fclose(fp)) {
              *err_no = -1;
            }
            int wstatus = -1;
+            // Don't do this before the parent has read the data from the
+            // child pipe, or it will hang on large outputs!
            waitpid(child_pid, &wstatus, 0);
-            if (wstatus == 0 || wstatus == (128 + SIGPIPE) * 256 ||
-                (wstatus == -1 && errno == ECHILD)) {
+
+            if (status) {
+              *status = wstatus;
+            }
+
+            if (WIFEXITED(wstatus) || wstatus == (128 + SIGPIPE) * 256) {
            } else {
+              PADDLE_ENFORCE_NE(
+                  errno, ECHILD,
+                  platform::errors::Fatal("Must not be ECHILD errno here!"));
              *err_no = -1;
+              LOG(WARNING) << "status[" << wstatus << "], cmd[" << cmd << "]"
+                           << ", err_no[" << *err_no << "]";
            }
-            if (wstatus == -1 && errno == ECHILD) {
-              // temporarily remove this warning
-              // LOG(WARNING) << "errno is ECHILD";
-            }
+            signal(SIGCHLD, old_handler);
          }};
#endif
}
static int shell_p2open_fork_internal(const char* real_cmd, int pipein_fds[2],
                                      int pipeout_fds[2]) {
-#if defined _WIN32 || defined __APPLE__
+#if defined(_WIN32) || defined(__APPLE__) || defined(PADDLE_ARM)
  return 0;
#else
  int child_pid = -1;
...@@ -243,7 +276,7 @@ static int shell_p2open_fork_internal(const char* real_cmd, int pipein_fds[2],

std::pair<std::shared_ptr<FILE>, std::shared_ptr<FILE>> shell_p2open(
    const std::string& cmd) {
-#if defined _WIN32 || defined __APPLE__
+#if defined(_WIN32) || defined(__APPLE__) || defined(PADDLE_ARM)
  return {};
#else
  if (shell_verbose()) {
...@@ -301,51 +334,102 @@ std::pair<std::shared_ptr<FILE>, std::shared_ptr<FILE>> shell_p2open(
#endif
}

-std::string shell_get_command_output(const std::string& cmd, int time_out,
-                                     int sleep_inter, bool print_cmd) {
-#if defined _WIN32 || defined __APPLE__
-  PADDLE_THROW(platform::errors::Unimplemented(
-      "This function(shell_get_command_output) is not implemented under _WIN32 "
-      "or __APPLE__."));
-#else
-  int err_no = 0;
-  platform::Timer timer;
-  do {
-    if (print_cmd) {
-      LOG(INFO) << "exec cmd:[" << cmd << "]";
-    }
-    err_no = 0;
-    std::shared_ptr<FILE> pipe = shell_popen(cmd, "r", &err_no);
-    string::LineFileReader reader;
-
-    char* buf = reader.getdelim(&*pipe, 0);
-    if (err_no == 0) {
-      if (buf) {
-        return reader.get();
-      }
-      return "";
-    }
-
-    if (sleep_inter > 0) {
-      usleep(sleep_inter);
-    }
-
-    timer.Pause();
-    if (time_out > 0 && timer.ElapsedMS() >= time_out) {
-      PADDLE_THROW(paddle::platform::errors::ExecutionTimeout(
-          "shell_get_command_output execute error errno:%d and try until "
-          "timeout.",
-          errno));
-      return "";
-    }
-    timer.Resume();
-
-    pipe = nullptr;
-  } while (err_no);
-
-  return "";
-#endif
-}
+#if defined(_WIN32) || defined(__APPLE__) || defined(PADDLE_ARM)
+#else
+static int _get_err_no(int err_no, int status) {
+  if (err_no == 0) {
+    if (WIFEXITED(status)) {
+      return WEXITSTATUS(status);
+    }
+    return -1;
+  }
+  return err_no;
+}
+#endif
+
+static int _shell_execute_cmd(const std::string& cmd, std::string* output,
+                              int time_out, int sleep_inter,
+                              bool redirect_stderr = false) {
+#if defined(_WIN32) || defined(__APPLE__) || defined(PADDLE_ARM)
+  PADDLE_THROW(platform::errors::Unimplemented(
+      "This function(shell_get_command_output) is not implemented under _WIN32 "
+      "or __APPLE__."));
+#else
+  int err_no = 0;
+  int status = 0;
+  int cmd_status = 0;
+  platform::Timer timer;
+  do {
+    VLOG(3) << "exec cmd:[" << cmd << "]";
+
+    err_no = 0;
+    status = 0;
+    *output = "";
+    auto pipe = shell_popen(cmd, "r", &err_no, &status, redirect_stderr);
+
+    if (err_no == 0) {
+      // read the command's output from the pipe
+      err_no = read_from_pipe(&*pipe, output);
+      if (err_no) {
+        LOG(WARNING) << "status[" << status << "], cmd[" << cmd << "]"
+                     << ", err_no[" << err_no << "]";
+      }
+    }
+
+    // close the file, reap the child, etc.
+    pipe = nullptr;
+    if (err_no) {
+      LOG(WARNING) << "status[" << status << "], cmd[" << cmd << "]"
+                   << ", err_no[" << err_no << "]";
+    }
+
+    cmd_status = _get_err_no(err_no, status);
+    // the cmd ran OK
+    if (cmd_status == 0) {
+      return cmd_status;
+    }
+
+    // check the timeout
+    timer.Pause();
+    if ((time_out > 0 && timer.ElapsedMS() >= time_out) || time_out == 0) {
+      break;
+    }
+    timer.Resume();
+
+    if (sleep_inter > 0) {
+      usleep(sleep_inter * 1000);
+    }
+  } while (cmd_status);
+
+  // log when the retry loop ended without success
+  if (time_out != 0) {
+    *output += string::Sprintf(
+        " _shell_execute_cmd execute cmd:%s ElapsedMS:%d, err_no:%d status:%d",
+        cmd, timer.ElapsedMS(), err_no, cmd_status);
+    LOG(WARNING) << *output;
+  }
+
+  return cmd_status;
+#endif
+}
+
+std::string shell_get_command_output(const std::string& cmd, int time_out,
+                                     int sleep_inter) {
+  std::string output;
+  _shell_execute_cmd(cmd, &output, time_out, sleep_inter);
+  return output;
+}
+
+std::vector<std::string> shell_execute_cmd(const std::string& cmd, int time_out,
+                                           int sleep_inter,
+                                           bool redirect_stderr) {
+  std::string output;
+  int ret =
+      _shell_execute_cmd(cmd, &output, time_out, sleep_inter, redirect_stderr);
+  return std::vector<std::string>({string::Sprintf("%d", ret), output});
+}

}  // end namespace framework
}  // end namespace paddle
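A usage sketch for the new shell_execute_cmd helper declared below in shell.h: it returns the exit status and the captured output as a two-element vector (the command here is purely illustrative):

#include <iostream>
#include <string>
#include <vector>

#include "paddle/fluid/framework/io/shell.h"

void ListDirectory() {
  // time_out = 0: single attempt, no retry; sleep_inter = 0: no sleep between
  // retries; redirect_stderr = true: capture stderr along with stdout.
  std::vector<std::string> ret =
      paddle::framework::shell_execute_cmd("ls -l", 0, 0, true);
  std::cout << "exit status: " << ret[0] << "\n"
            << "output: " << ret[1] << std::endl;
}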
...@@ -28,6 +28,7 @@

#include <memory>
#include <string>
#include <utility>
+#include <vector>

#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/string/string_helper.h"
...@@ -51,8 +52,10 @@ inline void shell_set_verbose(bool x) { shell_verbose_internal() = x; }

extern std::shared_ptr<FILE> shell_fopen(const std::string& path,
                                         const std::string& mode);

-extern std::shared_ptr<FILE> shell_popen(const std::string& cmd,
-                                         const std::string& mode, int* err_no);
+std::shared_ptr<FILE> shell_popen(const std::string& cmd,
+                                  const std::string& mode, int* err_no,
+                                  int* status = NULL,
+                                  bool redirect_stderr = false);

extern std::pair<std::shared_ptr<FILE>, std::shared_ptr<FILE>> shell_p2open(
    const std::string& cmd);
...@@ -65,12 +68,17 @@ inline void shell_execute(const std::string& cmd) {
  } while (err_no == -1);
}

-// timeout:ms, default -1 means forever.
-// sleep_inter:ms, default -1 means not sleep.
+// time_out: ms; the default retries for 10 minutes, -1 means retry forever.
+// sleep_inter: ms to sleep between retries.
extern std::string shell_get_command_output(const std::string& cmd,
                                            int time_out = 10 * 60 * 1000,
                                            int sleep_inter = 1000);

+// time_out: ms; the default 0 means a single attempt without retry.
+// sleep_inter: ms to sleep between retries, 0 means no sleep.
extern std::vector<std::string> shell_execute_cmd(const std::string& cmd,
                                                  int time_out = 0,
                                                  int sleep_inter = 0,
                                                  bool redirect_stderr = false);

}  // namespace framework
}  // namespace paddle
...@@ -135,7 +135,9 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,

void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op) {
  // Check parameters
-  PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
+  PADDLE_ENFORCE_EQ(graph->Has(kParamScopeAttr), true,
+                    platform::errors::InvalidArgument(
+                        "Graph has no attribute: kParamScopeAttr."));
  auto& scope = graph->Get<Scope>(kParamScopeAttr);

  // Create new parameters.
...@@ -193,7 +195,10 @@ void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op) {
  // reshape attention_bias
  auto* attention_bias_t =
      scope.FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
-  PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
+  PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1,
+                    platform::errors::InvalidArgument(
+                        "Tensor attention bias dimension size(%d) must be 1.",
+                        attention_bias_t->dims().size()));
  attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));

  auto* attention_scalar_bias_t =
...@@ -252,7 +257,10 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
      B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(),
      B_cell.data<float>()};

-  PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1);
+  PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1,
+                    platform::errors::InvalidArgument(
+                        "Tensor B forget dimension size(%d) must be 1.",
+                        B_forget.dims().size()));
  int D = B_forget.dims()[0];
  out->Resize(make_ddim({1, 4 * D}));
  auto* out_data = out->mutable_data<float>(platform::CPUPlace());
......
@@ -119,9 +119,11 @@ class CoalesceGradTensorPass : public ir::Pass {
         p_g_dense_grad.insert(p_g_dense_grad.end(), group_p_g.begin(),
                               group_p_g.end());
       }
-      PADDLE_ENFORCE_EQ(
-          p_g_dense_grad.size(), num_of_p_g_dense_grad,
-          "The number of p_g_dense_grad is not consistent with before.");
+      PADDLE_ENFORCE_EQ(p_g_dense_grad.size(), num_of_p_g_dense_grad,
+                        platform::errors::InvalidArgument(
+                            "The number of dense grads is not consistent with "
+                            "previous. Previous(%d), now(%d).",
+                            p_g_dense_grad.size(), num_of_p_g_dense_grad));
       auto &pinned_var_set =
           graph->GetOrInit<details::PinnedVars>(details::kPinnedVars);
@@ -131,8 +133,11 @@ class CoalesceGradTensorPass : public ir::Pass {
     } else {
       for (auto &sub_param_grad : group_params_grads) {
         RecordGradients(p_g_dense_grad, vars_info, &pinned_var_set);
-        PADDLE_ENFORCE_EQ(IsUnifiedDtype(sub_param_grad, vars_info), true,
-                          "The data type of the same group is not consistent.");
+        PADDLE_ENFORCE_EQ(
+            IsUnifiedDtype(sub_param_grad, vars_info), true,
+            platform::errors::InvalidArgument("All gradient variable in "
+                                              "kGroupParamsAndDenseGrads, must "
+                                              "have same type."));
         CoalesceTensors(vars_info, sub_param_grad, &result);
       }
     }
@@ -145,15 +150,25 @@ class CoalesceGradTensorPass : public ir::Pass {
     // The Gradients should not be reused during memory optimization.
     for (auto &p_g : sub_param_grad) {
       auto iter = vars_info.find(p_g.second);
-      PADDLE_ENFORCE_EQ(iter != vars_info.end(), true, "%s is not found.",
-                        p_g.second);
-      PADDLE_ENFORCE_EQ(!iter->second.empty(), true);
+      PADDLE_ENFORCE_EQ(iter != vars_info.end(), true,
+                        platform::errors::NotFound(
+                            "Parameter@Grad %s is not found.", p_g.second));
+      PADDLE_ENFORCE_EQ(
+          !iter->second.empty(), true,
+          platform::errors::InvalidArgument(
+              "Parameter@Grad %s's var node is empty.", p_g.second));
       for (auto it : iter->second) {
-        PADDLE_ENFORCE_NOT_NULL(it->Var());
+        PADDLE_ENFORCE_NOT_NULL(
+            it->Var(),
+            platform::errors::InvalidArgument(
+                "A node of Parameter@Grad %s does not hold variable.",
+                p_g.second));
         pinned_var_set->insert(it->Var()->Name());
       }
       PADDLE_ENFORCE_EQ(IsLoDTensorType(GetTypeOfVar(vars_info, p_g.second)),
-                        true);
+                        true,
+                        platform::errors::InvalidArgument(
+                            "Parameter@Grad %s is not LoDTensor.", p_g.second));
     }
   }
@@ -192,8 +207,10 @@ class CoalesceGradTensorPass : public ir::Pass {
     auto fused_grad_var_name = std::string(details::kFusedVarNamePrefix) +
                                "@GRAD@" + params_grads.begin()->second;
     auto &fused_var_set = result->Get<details::FusedVars>(details::kFusedVars);
-    PADDLE_ENFORCE_EQ(fused_var_set.count(fused_grad_var_name), 0,
-                      "%s is duplicate in FusedVars.", fused_grad_var_name);
+    PADDLE_ENFORCE_EQ(
+        fused_var_set.count(fused_grad_var_name), 0,
+        platform::errors::AlreadyExists("Var(%s) is duplicate in FusedVars.",
+                                        fused_grad_var_name));
     fused_var_set.insert(fused_grad_var_name);
     result->Get<details::FusedGrads>(details::kFusedGrads)
         .emplace_back(fused_grad_var_name);
@@ -420,11 +437,16 @@ class CoalesceGradTensorPass : public ir::Pass {
       const std::unordered_map<std::string, std::vector<Node *>> &vars_info,
       const std::string &var_name) const {
     auto grad_iter = vars_info.find(var_name);
-    PADDLE_ENFORCE_EQ(grad_iter != vars_info.end(), true, "%s is not found.",
-                      var_name);
-    PADDLE_ENFORCE_EQ(!grad_iter->second.empty(), true, "%s is not found.",
-                      var_name);
-    PADDLE_ENFORCE_NOT_NULL(grad_iter->second.front()->Var());
+    PADDLE_ENFORCE_EQ(
+        grad_iter != vars_info.end(), true,
+        platform::errors::NotFound("Variable %s is not found.", var_name));
+    PADDLE_ENFORCE_EQ(!grad_iter->second.empty(), true,
+                      platform::errors::InvalidArgument(
+                          "Variable %s's node is empty.", var_name));
+    PADDLE_ENFORCE_NOT_NULL(
+        grad_iter->second.front()->Var(),
+        platform::errors::InvalidArgument(
+            "A node of %s does not hold variable.", var_name));
     return grad_iter->second.front()->Var();
   }
@@ -464,7 +486,12 @@ class CoalesceGradTensorPass : public ir::Pass {
       params_name.emplace_back(p_g.first);
       grads_name.emplace_back(p_g.second);
      auto next_dtype = GetDtypeOfVar(vars_info, p_g.second);
-      PADDLE_ENFORCE_EQ(next_dtype, dtype);
+      PADDLE_ENFORCE_EQ(
+          next_dtype, dtype,
+          platform::errors::InvalidArgument(
+              "All Parameter@Grad should have same dtype, but "
+              "there are two different type: %s, %s.",
+              DataTypeToString(next_dtype), DataTypeToString(dtype)));
     }
     result->Get<details::ProgramDescs>(details::kProgramDescs).emplace_back();
...
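Beyond richer messages, the rewritten checks in this pass pick an error category that matches the failure: `NotFound` for a missing key, `AlreadyExists` for a duplicate registration, `InvalidArgument` for a value that exists but is unusable. A small sketch of that taxonomy with hypothetical, simplified types (the real `platform::errors` objects also carry error codes):

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical simplified error types mirroring the categories used above.
struct NotFound : std::runtime_error { using std::runtime_error::runtime_error; };
struct AlreadyExists : std::runtime_error { using std::runtime_error::runtime_error; };
struct InvalidArgument : std::runtime_error { using std::runtime_error::runtime_error; };

// Mirrors the kFusedVars check: registering the same fused name twice
// is a duplicate, not a bad argument.
void RegisterFusedVar(std::map<std::string, int>* fused, const std::string& name) {
  if (fused->count(name) != 0)
    throw AlreadyExists("Var(" + name + ") is duplicate in FusedVars.");
  (*fused)[name] = 1;
}

// Mirrors the vars_info lookups: an absent key is NotFound, while a key
// that exists but maps to an empty node list is InvalidArgument.
int FrontNode(const std::map<std::string, std::vector<int>>& vars,
              const std::string& name) {
  auto it = vars.find(name);
  if (it == vars.end())
    throw NotFound("Variable " + name + " is not found.");
  if (it->second.empty())
    throw InvalidArgument("Variable " + name + "'s var node is empty.");
  return it->second.front();
}

int main() {
  std::map<std::string, int> fused;
  RegisterFusedVar(&fused, "@FUSEDVAR@w@GRAD@");
  try {
    RegisterFusedVar(&fused, "@FUSEDVAR@w@GRAD@");  // second time: duplicate
  } catch (const AlreadyExists& e) {
    std::cout << e.what() << '\n';
  }
  return 0;
}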
@@ -50,7 +50,12 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
       Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
   // Re-compute bias of conv2d from AffineChannel
-  PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), ac_bias_tensor.dims());
+  PADDLE_ENFORCE_EQ(
+      eltwise_y_in_tensor->dims(), ac_bias_tensor.dims(),
+      platform::errors::InvalidArgument(
+          "Tensor elementwise y(%d) and activation bias(%d) must have same "
+          "dimension.",
+          eltwise_y_in_tensor->dims().size(), ac_bias_tensor.dims().size()));
   auto* scale_tensor = scope->FindVar(ac_scale.Name())->GetMutable<LoDTensor>();
@@ -78,11 +83,13 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
 }
 void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
-  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init(name_scope_, graph);
   auto* scope = param_scope();
-  PADDLE_ENFORCE(scope);
+  PADDLE_ENFORCE_NOT_NULL(
+      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
   GraphPatternDetector gpd;
   auto* conv_input =
@@ -152,11 +159,13 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
 }
 void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
-  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init(name_scope_, graph);
   auto* scope = param_scope();
-  PADDLE_ENFORCE(scope);
+  PADDLE_ENFORCE_NOT_NULL(
+      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
   GraphPatternDetector gpd;
   auto* conv_input =
...
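The `PADDLE_ENFORCE(graph)` and `PADDLE_ENFORCE(scope)` rewrites above deserve a note: the old form funnels a raw pointer through a truthiness check, while `PADDLE_ENFORCE_NOT_NULL` names the intent and attaches a message. A sketch with a simplified stand-in macro and a hypothetical placeholder for `ir::Graph`:

#include <cstdio>
#include <stdexcept>

// Simplified stand-in for PADDLE_ENFORCE_NOT_NULL: checks the pointer
// itself instead of relying on an implicit pointer-to-bool conversion.
#define ENFORCE_NOT_NULL(ptr, msg)                       \
  do {                                                   \
    if ((ptr) == nullptr) throw std::runtime_error(msg); \
  } while (0)

struct Graph {};  // hypothetical placeholder for ir::Graph

void ApplyPass(Graph* graph) {
  ENFORCE_NOT_NULL(graph, "Graph cannot be nullptr.");
  // ... pass body would run here ...
}

int main() {
  try {
    ApplyPass(nullptr);  // deliberately trigger the check
  } catch (const std::exception& e) {
    std::printf("%s\n", e.what());
  }
  return 0;
}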
@@ -61,7 +61,12 @@ void recompute_bias_and_weights(const Scope* scope,
       Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
   // Re-compute bias of conv2d from BN
-  PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), bn_bias_tensor.dims());
+  PADDLE_ENFORCE_EQ(
+      eltwise_y_in_tensor->dims(), bn_bias_tensor.dims(),
+      platform::errors::InvalidArgument("Tensor elementwise y(%d) and batch "
+                                        "norm bias(%d) must have same dims.",
+                                        eltwise_y_in_tensor->dims().size(),
+                                        bn_bias_tensor.dims().size()));
   auto* scale_tensor = scope->FindVar(bn_scale.Name())->GetMutable<LoDTensor>();
   auto* variance_tensor =
@@ -116,11 +121,13 @@ void recompute_bias_and_weights(const Scope* scope,
 }
 void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
-  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init(name_scope_, graph);
   auto* scope = param_scope();
-  PADDLE_ENFORCE(scope);
+  PADDLE_ENFORCE_NOT_NULL(
+      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
   GraphPatternDetector gpd;
   auto* conv_input =
@@ -186,11 +193,18 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
     if (has_bias && conv->Op()->Input("Bias").size() > 0) {
       // reuse existing conv bias node
       auto conv_bias_names = conv->Op()->Input("Bias");
-      PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1UL);
+      PADDLE_ENFORCE_EQ(
+          conv_bias_names.size(), 1UL,
+          platform::errors::InvalidArgument("Find input var Bais error."));
       auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
       auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
-      PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(),
-                        eltwise_y_in_tensor->dims());
+      PADDLE_ENFORCE_EQ(
+          conv_bias_tensor->dims(), eltwise_y_in_tensor->dims(),
+          platform::errors::InvalidArgument(
+              "Tensor convolution bias(%d) and elementwise y(%d) "
+              "must have same dims.",
+              conv_bias_tensor->dims().size(),
+              eltwise_y_in_tensor->dims().size()));
       auto eigen_conv_bias = EigenVector<float>::From(*conv_bias_tensor);
       eigen_conv_bias += EigenVector<float>::From(*eltwise_y_in_tensor);
@@ -236,11 +250,13 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
 }
 void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
-  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init(name_scope_, graph);
   auto* scope = param_scope();
-  PADDLE_ENFORCE(scope);
+  PADDLE_ENFORCE_NOT_NULL(
+      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
   GraphPatternDetector gpd;
   auto* conv_input =
...
@@ -71,8 +71,16 @@ void TestMain(const std::string& conv_type) {
   int num_bn_nodes_after = GetNumOpNodes(graph, "batch_norm");
   VLOG(3) << DebugString(graph);
-  PADDLE_ENFORCE_EQ(num_bn_nodes_before, 1);
-  PADDLE_ENFORCE_EQ(num_bn_nodes_after, 0);
+  PADDLE_ENFORCE_EQ(
+      num_bn_nodes_before, 1,
+      platform::errors::InvalidArgument(
+          "Before conv_bn_fuse_pass, number of batch norm op(%d) must be 1.",
+          num_bn_nodes_before));
+  PADDLE_ENFORCE_EQ(
+      num_bn_nodes_after, 0,
+      platform::errors::InvalidArgument(
+          "After conv_bn_fuse_pass, number of batch norm op(%d) must be 0.",
+          num_bn_nodes_after));
 }
 TEST(ConvBNFusePass, conv2d) { TestMain("conv"); }
...
@@ -91,7 +91,9 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
   auto* new_conv_op = graph->CreateOpNode(&new_op_desc);
   // Link inputs and outputs.
-  PADDLE_ENFORCE(subgraph.count(x));
+  PADDLE_ENFORCE_NE(
+      subgraph.count(x), 0,
+      platform::errors::NotFound("Detector did not find input x of conv2d."));
   auto* conv_in_node = subgraph.at(x);
   IR_NODE_LINK_TO(conv_in_node, new_conv_op);  // Input
...
@@ -78,7 +78,9 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
   auto* new_conv_op = graph->CreateOpNode(&new_op_desc);
   // Link inputs and outputs.
-  PADDLE_ENFORCE(subgraph.count(x));
+  PADDLE_ENFORCE_NE(
+      subgraph.count(x), 0,
+      platform::errors::NotFound("Detector did not find input x of conv2d."));
   auto* conv_in_node = subgraph.at(x);
   IR_NODE_LINK_TO(conv_in_node, new_conv_op);  // Input
...
@@ -66,7 +66,9 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
   auto* new_conv_op = graph->CreateOpNode(&new_op_desc);
   // Link inputs and outputs.
-  PADDLE_ENFORCE(subgraph.count(x));
+  PADDLE_ENFORCE_NE(
+      subgraph.count(x), 0,
+      platform::errors::NotFound("Detector did not find input x of conv2d."));
   auto* conv_in_node = subgraph.at(x);
   IR_NODE_LINK_TO(conv_in_node, new_conv_op);  // Input
...
@@ -64,17 +64,23 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
 #undef SET_IN
   // Multiply embeddings with Weights
-  PADDLE_ENFORCE(scope);
+  PADDLE_ENFORCE_NOT_NULL(
+      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
   const std::string& embeddings = patterns::UniqueKey("Embeddings");
   auto* embeddings_var = scope->Var(embeddings);
-  PADDLE_ENFORCE(embeddings_var);
+  PADDLE_ENFORCE_NOT_NULL(
+      embeddings_var,
+      platform::errors::InvalidArgument(
+          "Embeddings variable's pointer cannot be nullptr."));
   auto* embeddings_tensor =
       embeddings_var->GetMutable<framework::LoDTensor>();
   // Get WeightX size: [single_embedding, fc_size]
   // and embedding size: [dict_size, single_embedding]
   // and create new size of embeddings eg. [dict_size , hidden_size]
   auto* embedding_var = scope->FindVar(W->Name());
-  PADDLE_ENFORCE(embedding_var);
+  PADDLE_ENFORCE_NOT_NULL(
+      embedding_var, platform::errors::InvalidArgument(
+                         "Embedding variable's pointer cannot be nullptr."));
   const auto& embedding_tensor = embedding_var->Get<framework::LoDTensor>();
   const auto& weightx_tensor =
@@ -90,7 +96,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   // Adding biases to GEMM result to be
   auto* lstm_bias_var = scope->FindVar(bias->Name());
-  PADDLE_ENFORCE(lstm_bias_var);
+  PADDLE_ENFORCE_NOT_NULL(lstm_bias_var,
+                          platform::errors::InvalidArgument(
+                              "Lstm bias var ptr cannot be nullptr."));
   const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();
   auto alpha = 1.0f;
...
@@ -56,8 +56,17 @@ TEST(FCElementwiseLayerNormFusePass, basic) {
       GetNumOpNodes(graph, "fused_fc_elementwise_layernorm");
   VLOG(3) << DebugString(graph);
-  PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6);
-  PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1);
+  PADDLE_ENFORCE_EQ(
+      num_nodes_before, num_nodes_after + 6,
+      platform::errors::InvalidArgument(
+          "After pass, the number of nodes should be reduced by 6, but the "
+          "number before pass is %d, after pass is %d.",
+          num_nodes_before, num_nodes_after));
+  PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1,
+                    platform::errors::InvalidArgument(
+                        "After pass, the number of nodes of type "
+                        "'fused_fc_elementwise_layernorm' should be 1, not %d.",
+                        num_fused_nodes_after));
 }
 }  // namespace ir
...
@@ -25,7 +25,8 @@ namespace framework {
 namespace ir {
 void FCFusePass::ApplyImpl(ir::Graph* graph) const {
-  PADDLE_ENFORCE_NOT_NULL(graph);
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init("fc_fuse", graph);
   int found_fc_count = 0;
...
@@ -79,9 +79,17 @@ TEST(FCFusePass, basic) {
   int num_fc_nodes_after = GetNumOpNodes(graph, "fc");
   VLOG(3) << DebugString(graph);
-  PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6);
-  PADDLE_ENFORCE_EQ(num_fc_nodes_after, 2);
-  PADDLE_ENFORCE_EQ(num_mul_nodes_before, num_fc_nodes_after);
+  PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6,
+                    platform::errors::InvalidArgument(
+                        "num_nodes_before=%d, num_nodes_after=%d.",
+                        num_nodes_before, num_nodes_after));
+  PADDLE_ENFORCE_EQ(num_fc_nodes_after, 2,
+                    platform::errors::InvalidArgument("num_fc_nodes_after=%d.",
+                                                      num_fc_nodes_after));
+  PADDLE_ENFORCE_EQ(num_mul_nodes_before, num_fc_nodes_after,
+                    platform::errors::InvalidArgument(
+                        "num_mul_nodes_before=%d, num_fc_nodes_after=%d.",
+                        num_mul_nodes_before, num_fc_nodes_after));
 }
 }  // namespace ir
...
@@ -68,18 +68,27 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
 #undef SET_IMTERMEDIATE_OUT
   auto* op = graph->CreateOpNode(&op_desc);
-  PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
+  PADDLE_ENFORCE_EQ(graph->Has(kParamScopeAttr), true,
+                    platform::errors::InvalidArgument(
+                        "Graph have no attr kParamScopeAttr."));
   auto& scope = graph->Get<Scope>(kParamScopeAttr);
   if (with_fc_bias) {
     // Fusion GRU bias = fcbias + grubias
     auto* fusion_bias_var = scope.Var(NEW_NAME(bias) + bias->Name());
     auto* out_bias_tensor =
         fusion_bias_var->GetMutable<framework::LoDTensor>();
-    PADDLE_ENFORCE(fusion_bias_var);
+    PADDLE_ENFORCE_NOT_NULL(
+        fusion_bias_var,
+        platform::errors::InvalidArgument(
+            "Fusion bias variable's pointer cannot be nullptr."));
     auto* gru_bias_var = scope.FindVar(bias->Name());
     auto* fc_bias_var = scope.FindVar(fc_bias->Name());
-    PADDLE_ENFORCE(gru_bias_var);
-    PADDLE_ENFORCE(fc_bias_var);
+    PADDLE_ENFORCE_NOT_NULL(gru_bias_var,
+                            platform::errors::InvalidArgument(
+                                "Gru bias var ptr cannot be nullptr."));
+    PADDLE_ENFORCE_NOT_NULL(fc_bias_var,
+                            platform::errors::InvalidArgument(
+                                "Fc bias var ptr cannot be nullptr."));
     const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>();
     const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
     // new bias = fc bias + gru bias
...
@@ -52,13 +52,17 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
 #undef SET_IN
   if (with_fc_bias) {
     // Add FC-bias with LSTM-bias and create a new weight
-    PADDLE_ENFORCE(scope);
+    PADDLE_ENFORCE_NOT_NULL(
+        scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
     const std::string& new_bias_var = patterns::UniqueKey("NewBias");
     auto* bias_var = scope->Var(new_bias_var);
-    PADDLE_ENFORCE(bias_var);
+    PADDLE_ENFORCE_NOT_NULL(bias_var, platform::errors::InvalidArgument(
+                                          "Bias var ptr cannot be nullptr."));
     auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
     auto* lstm_bias_var = scope->FindVar(bias->Name());
-    PADDLE_ENFORCE(lstm_bias_var);
+    PADDLE_ENFORCE_NOT_NULL(lstm_bias_var,
+                            platform::errors::InvalidArgument(
+                                "Lstm bias var ptr cannot be nullptr."));
     const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();
     bias_tensor->Resize(lstm_bias_tensor.dims());
...
@@ -50,18 +50,25 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
                                  fused_scale2->inputs.end());
     for (auto &out_node : fused_scale1->outputs) {
       if (fused_scale2_in_nodes.count(out_node)) {
-        PADDLE_ENFORCE(out_node->IsCtrlVar(),
-                       "The dependency var only should be ctrl var.");
+        PADDLE_ENFORCE_EQ(out_node->IsCtrlVar(), true,
+                          platform::errors::PreconditionNotMet(
+                              "In adam op pass, the dependency var(%s) only "
+                              "should be ctrl var.",
+                              out_node->Name()));
         not_need_ctrl_var_nodes.insert(out_node);
       }
     }
     for (auto &node : not_need_ctrl_var_nodes) {
       // remove this node from the input op node.
-      PADDLE_ENFORCE(!node->inputs.empty(),
-                     "The input should not be empty here.");
+      PADDLE_ENFORCE_EQ(
+          node->inputs.empty(), false,
+          platform::errors::PreconditionNotMet(
+              "Node(%s)'s input should not be empty here.", node->Name()));
       auto op_node = node->inputs.front();
-      PADDLE_ENFORCE(op_node->IsOp());
+      PADDLE_ENFORCE_EQ(op_node->IsOp(), true,
+                        platform::errors::PreconditionNotMet(
+                            "Node(%s) should be an OP node.", op_node->Name()));
       op_node->outputs.erase(
           remove_if(
               op_node->outputs.begin(), op_node->outputs.end(),
@@ -85,7 +92,9 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
       const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
       const std::unordered_map<std::string, std::string> &fused_vars_name,
       const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
-    PADDLE_ENFORCE_GT(adam_ops.size(), static_cast<size_t>(0));
+    PADDLE_ENFORCE_GT(
+        adam_ops.size(), static_cast<size_t>(0),
+        platform::errors::InvalidArgument("No adam op in the graph."));
     // Check attributions
     // NOTE: If new attribution is added, the following code maybe need change.
@@ -102,22 +111,58 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
         int64_t, adam_ops[0]->Op()->GetAttr("min_row_size_to_use_multithread"));
     for (auto &adam_op : adam_ops) {
       PADDLE_ENFORCE_EQ(
-          beta1, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("beta1")));
+          beta1, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("beta1")),
+          platform::errors::PreconditionNotMet(
+              "All adam Op's attr(beta1) must be same, but there are two "
+              "different "
+              "value: %f, %f.",
+              beta1, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("beta1"))));
       PADDLE_ENFORCE_EQ(
-          beta2, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("beta2")));
+          beta2, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("beta2")),
+          platform::errors::PreconditionNotMet(
+              "All adam Op's attr(beta2) must be same, but there are two "
+              "different "
+              "value: %f, %f.",
+              beta2, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("beta2"))));
       PADDLE_ENFORCE_EQ(
-          epsilon, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("epsilon")));
+          epsilon, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("epsilon")),
+          platform::errors::PreconditionNotMet(
+              "All adam Op's attr(epsilon) must be same, but there are two "
+              "different "
+              "value: %f, %f.",
+              epsilon,
+              BOOST_GET_CONST(float, adam_op->Op()->GetAttr("epsilon"))));
       PADDLE_ENFORCE_EQ(
-          lazy_mode,
-          BOOST_GET_CONST(bool, adam_op->Op()->GetAttr("lazy_mode")));
+          lazy_mode, BOOST_GET_CONST(bool, adam_op->Op()->GetAttr("lazy_mode")),
+          platform::errors::PreconditionNotMet(
+              "All adam Op's attr(lazy_mode) must be same, but there are two "
+              "different "
+              "value: %d, %d.",
+              lazy_mode,
+              BOOST_GET_CONST(bool, adam_op->Op()->GetAttr("lazy_mode"))));
       PADDLE_ENFORCE_EQ(
           min_row_size_to_use_multithread,
           BOOST_GET_CONST(int64_t, adam_op->Op()->GetAttr(
-                                        "min_row_size_to_use_multithread")));
+                                        "min_row_size_to_use_multithread")),
+          platform::errors::PreconditionNotMet(
+              "All adam Op's attr(min_row_size_to_use_multithread) must be "
+              "same, but there are two different value: %I64, %I64.",
+              min_row_size_to_use_multithread,
+              BOOST_GET_CONST(
+                  int64_t,
+                  adam_op->Op()->GetAttr("min_row_size_to_use_multithread"))));
       PADDLE_ENFORCE_EQ(
           op_role,
           BOOST_GET_CONST(int, adam_op->Op()->GetAttr(
-                                   OpProtoAndCheckerMaker::OpRoleAttrName())));
+                                   OpProtoAndCheckerMaker::OpRoleAttrName())),
+          platform::errors::PreconditionNotMet(
+              "All adam Op's attr(op_role) must be same, but there are two "
+              "different "
+              "value: %d, %d.",
+              op_role,
+              BOOST_GET_CONST(int,
+                              adam_op->Op()->GetAttr(
+                                  OpProtoAndCheckerMaker::OpRoleAttrName()))));
     }
     // NOTE: fused_var is only exist in scope, so the graph doesn't have
@@ -154,7 +199,10 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
                              const std::string &fused_var_name,
                              const std::vector<ir::Node *> &adam_ops,
                              ir::Graph *graph) const {
-    PADDLE_ENFORCE_EQ(beta_name.size(), adam_ops.size());
+    PADDLE_ENFORCE_EQ(beta_name.size(), adam_ops.size(),
+                      platform::errors::InvalidArgument(
+                          "Beta name size(%d) must equal to adam op size(%d).",
+                          beta_name.size(), adam_ops.size()));
     const std::string scale_op_name = "scale";
     // Get the scale_ops of dealing the adam's beta var.
@@ -168,7 +216,9 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
             return var_node->Var() &&
                    var_node->Var()->Name() == beta_1_pow_name;
           });
-      PADDLE_ENFORCE(beta_pow_iter != adam_ops[i]->inputs.end());
+      PADDLE_ENFORCE_NE(beta_pow_iter, adam_ops[i]->inputs.end(),
+                        platform::errors::NotFound(
+                            "Can not find %s in adam ops.", beta_1_pow_name));
       auto beta_pow_node = *beta_pow_iter;
       auto scale_op_iter = std::find_if(
@@ -176,11 +226,18 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
           [&scale_op_name](ir::Node *op_node) -> bool {
             return op_node->Op() && op_node->Op()->Type() == scale_op_name;
           });
-      PADDLE_ENFORCE(scale_op_iter != beta_pow_node->outputs.end());
+      PADDLE_ENFORCE_NE(
+          scale_op_iter, beta_pow_node->outputs.end(),
+          platform::errors::NotFound("Can not find %s in beta pow node.",
+                                     scale_op_name));
       scale_ops.emplace_back(*scale_op_iter);
     }
-    PADDLE_ENFORCE_EQ(scale_ops.size(), beta_name.size());
+    PADDLE_ENFORCE_EQ(
+        scale_ops.size(), beta_name.size(),
+        platform::errors::PreconditionNotMet(
+            "Beta name size(%d) must equal to scale ops size(%d).",
+            beta_name.size(), scale_ops.size()));
     VLOG(6) << "The number of scale op is " << scale_ops.size() << ".";
     // Check attributions
     // NOTE: If new attribution is added, the following code maybe need change.
@@ -193,16 +250,40 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
         BOOST_GET_CONST(bool, scale_ops[0]->Op()->GetAttr("bias_after_scale"));
     for (auto &scale_op : scale_ops) {
       PADDLE_ENFORCE_EQ(
-          scale, BOOST_GET_CONST(float, scale_op->Op()->GetAttr("scale")));
+          scale, BOOST_GET_CONST(float, scale_op->Op()->GetAttr("scale")),
+          platform::errors::PreconditionNotMet(
+              "All scale Op's attr(scale) must be same, but there are two "
+              "different "
+              "value: %f, %f.",
+              scale, BOOST_GET_CONST(float, scale_op->Op()->GetAttr("scale"))));
       PADDLE_ENFORCE_EQ(
-          bias, BOOST_GET_CONST(float, scale_op->Op()->GetAttr("bias")));
+          bias, BOOST_GET_CONST(float, scale_op->Op()->GetAttr("bias")),
+          platform::errors::PreconditionNotMet(
+              "All scale Op's attr(bias) must be same, but there are two "
+              "different "
+              "value: %f, %f.",
+              bias, BOOST_GET_CONST(float, scale_op->Op()->GetAttr("bias"))));
       PADDLE_ENFORCE_EQ(
           bias_after_scale,
-          BOOST_GET_CONST(bool, scale_op->Op()->GetAttr("bias_after_scale")));
+          BOOST_GET_CONST(bool, scale_op->Op()->GetAttr("bias_after_scale")),
+          platform::errors::PreconditionNotMet(
+              "All scale Op's attr(bias_after_scale) must be same, but there "
+              "are two different value: %d, %d.",
+              bias_after_scale,
+              BOOST_GET_CONST(bool,
+                              scale_op->Op()->GetAttr("bias_after_scale"))));
       PADDLE_ENFORCE_EQ(
           op_role,
           BOOST_GET_CONST(int, scale_op->Op()->GetAttr(
-                                   OpProtoAndCheckerMaker::OpRoleAttrName())));
+                                   OpProtoAndCheckerMaker::OpRoleAttrName())),
+          platform::errors::PreconditionNotMet(
+              "All scale Op's attr(op_role) must be same, but there are two "
+              "different "
+              "value: %d, %d.",
+              op_role,
+              BOOST_GET_CONST(int,
+                              scale_op->Op()->GetAttr(
+                                  OpProtoAndCheckerMaker::OpRoleAttrName()))));
     }
     // NOTE: fused_var is only exist in scope, so the graph doesn't have
...
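The long adam hunks above all enforce the same precondition: fusing N adam ops is only legal when their hyper-parameter attributes agree, and the rewritten checks now report the first divergent pair instead of failing silently on location alone. A minimal sketch of that consistency check, using a hypothetical simplified op type rather than Paddle's `ir::Node`/`OpDesc`:

#include <cstdio>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical simplified op: just a name and a float attribute map.
struct Op {
  std::string name;
  std::map<std::string, float> attrs;
};

// Before fusing, verify every op agrees with the first one on the given
// attribute, and name the divergent values when they do not; this is the
// invariant FuseAdamOpPass enforces for beta1/beta2/epsilon and friends.
void CheckAttrConsistent(const std::vector<Op>& ops, const std::string& key) {
  if (ops.empty()) throw std::invalid_argument("No op to fuse.");
  const float expected = ops.front().attrs.at(key);
  for (const Op& op : ops) {
    const float got = op.attrs.at(key);
    if (got != expected) {
      throw std::runtime_error("All ops' attr(" + key +
                               ") must be same, but there are two different "
                               "values: " + std::to_string(expected) + ", " +
                               std::to_string(got) + ".");
    }
  }
}

int main() {
  std::vector<Op> adam_ops = {{"adam_0", {{"beta1", 0.9f}}},
                              {"adam_1", {{"beta1", 0.99f}}}};
  try {
    CheckAttrConsistent(adam_ops, "beta1");  // throws: 0.9 vs 0.99
  } catch (const std::exception& e) {
    std::printf("%s\n", e.what());  // the message pinpoints both values
  }
  return 0;
}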
@@ -37,7 +37,9 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
       const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
       const std::unordered_map<std::string, std::string> &fused_vars_name,
       const std::vector<ir::Node *> &momentum_ops, ir::Graph *graph) const {
-    PADDLE_ENFORCE_GT(momentum_ops.size(), static_cast<size_t>(0));
+    PADDLE_ENFORCE_GT(
+        momentum_ops.size(), static_cast<size_t>(0),
+        platform::errors::InvalidArgument("Momentum ops must not be empyt."));
     // Check attributions
     // NOTE: If new attribution is added, the following code maybe need change.
@@ -50,14 +52,32 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
     for (auto &momentum_op : momentum_ops) {
       PADDLE_ENFORCE_EQ(
-          mu, BOOST_GET_CONST(float, momentum_op->Op()->GetAttr("mu")));
+          mu, BOOST_GET_CONST(float, momentum_op->Op()->GetAttr("mu")),
+          platform::errors::InvalidArgument(
+              "All momentum Op's attr(mu) must be same, but there are two "
+              "different "
+              "value: %f, %f.",
+              mu, BOOST_GET_CONST(float, momentum_op->Op()->GetAttr("mu"))));
       PADDLE_ENFORCE_EQ(
           use_nesterov,
-          BOOST_GET_CONST(bool, momentum_op->Op()->GetAttr("use_nesterov")));
+          BOOST_GET_CONST(bool, momentum_op->Op()->GetAttr("use_nesterov")),
+          platform::errors::InvalidArgument(
+              "All momentum Op's attr(use_nesterov) must be same, but there "
+              "are two different value: %d, %d.",
+              use_nesterov, BOOST_GET_CONST(bool, momentum_op->Op()->GetAttr(
+                                                      "use_nesterov"))));
       PADDLE_ENFORCE_EQ(
           op_role,
           BOOST_GET_CONST(int, momentum_op->Op()->GetAttr(
-                                   OpProtoAndCheckerMaker::OpRoleAttrName())));
+                                   OpProtoAndCheckerMaker::OpRoleAttrName())),
+          platform::errors::InvalidArgument(
+              "All momentum Op's attr(op_role) must be same, but there are two "
+              "different "
+              "value: %d, %d.",
+              op_role,
+              BOOST_GET_CONST(int,
+                              momentum_op->Op()->GetAttr(
+                                  OpProtoAndCheckerMaker::OpRoleAttrName()))));
     }
     // NOTE: fused_var is only exist in scope, so the graph doesn't have
...
@@ -41,10 +41,12 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
   for (auto &node : topo_nodes) {
     if (node->Op()->Type() == fuse_op_type) {
       auto grad_name = node->Op()->Input(kGrad);
-      PADDLE_ENFORCE_EQ(grad_name.size(), static_cast<size_t>(1),
-                        "The %s operator has multiple gradient input. Expected "
-                        "it to only have one gradient input.",
-                        fuse_op_type);
+      PADDLE_ENFORCE_EQ(
+          grad_name.size(), static_cast<size_t>(1),
+          platform::errors::InvalidArgument(
+              "The %s operator has multiple gradient input. Expected "
+              "it to only have one gradient input.",
+              fuse_op_type));
       if (IsLoDTensorType(GetTypeOfVar(vars_info, grad_name[0]))) {
         opt_nodes.emplace_back(node);
       }
@@ -96,7 +98,8 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
     VLOG(6) << var_name << ": " << fused_var_name;
     PADDLE_ENFORCE_EQ(
         fused_var_set.count(fused_var_name), 0,
-        platform::errors::AlreadyExists("The fused variable already exists."));
+        platform::errors::AlreadyExists(
+            "The fused variable(%s) already exists.", fused_var_name));
     fused_var_set.insert(fused_var_name);
     fused_vars_name.emplace(var_name, fused_var_name);
   }
@@ -110,7 +113,10 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
       result.Get<details::ParamsAndGrads>(details::kParamsAndDenseGrads);
   PADDLE_ENFORCE_LE(
       params_and_dense_grads.size(), aux_var_map.at(kGrad).size(),
-      "The number of dense gradients should be little than optimizer ops.");
+      platform::errors::InvalidArgument(
+          "The number of dense gradients(%d) should be "
+          "little than optimizer ops(%d).",
+          params_and_dense_grads.size(), aux_var_map.at(kGrad).size()));
   std::unordered_set<std::string> opt_grad_set(aux_var_map.at(kGrad).size());
   for (auto &p_g : params_and_dense_grads) {
@@ -130,13 +136,14 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
   // some gradient's name maybe changed.
   if (new_grad_idx.size() == 0) {
     if (!result.Has(details::kFusedGrads)) {
-      PADDLE_THROW(
-          "The coalesce_grad_tensor_pass should "
-          "be called before this pass.");
+      PADDLE_THROW(platform::errors::PreconditionNotMet(
+          "The coalesce_grad_tensor_pass should "
+          "be called before this pass."));
     }
     auto &fused_grad = result.Get<details::FusedGrads>(details::kFusedGrads);
     PADDLE_ENFORCE_NE(fused_grad.size(), 0,
-                      "The fused gradient should not be empty.");
+                      platform::errors::NotFound(
+                          "The fused gradient should not be empty."));
     if (fused_grad.size() > 1) {
       // Note(chenweihang): Because the dtype of those gradients is not
       // unified,so the number of fused gradients is more than one,
@@ -146,8 +153,9 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
     auto &fused_vars = result.Get<details::FusedVars>(details::kFusedVars);
     auto iter =
         std::find(fused_vars.begin(), fused_vars.end(), fused_grad.front());
-    PADDLE_ENFORCE_EQ(iter != fused_vars.end(), true,
-                      "Not found the fused gradient variable.");
+    PADDLE_ENFORCE_EQ(
+        iter != fused_vars.end(), true,
+        platform::errors::NotFound("Not found the fused gradient variable."));
     fused_vars_name[kGrad] = fused_grad.front();
     // Sort the parameters and auxiliary variables according
@@ -334,16 +342,24 @@ void FuseOptimizerOpPass::FuseGradientsToContinuousSpace(
   // The Gradients should not be reused during memory optimization.
   for (auto &grad_var_name : grads) {
     auto iter = vars_info.find(grad_var_name);
-    PADDLE_ENFORCE_EQ(iter != vars_info.end(), true,
-                      "The gradient variable %s is not found.", grad_var_name);
-    PADDLE_ENFORCE_EQ(!iter->second.empty(), true,
-                      "The gradient var node %s is not found.", grad_var_name);
-    PADDLE_ENFORCE_NOT_NULL(iter->second.front()->Var(),
-                            "The gradient var node is null.");
+    PADDLE_ENFORCE_EQ(
+        iter != vars_info.end(), true,
+        platform::errors::NotFound("The gradient variable %s is not found.",
+                                   grad_var_name));
+    PADDLE_ENFORCE_EQ(
+        !iter->second.empty(), true,
+        platform::errors::NotFound("The gradient var node %s is not found.",
+                                   grad_var_name));
+    PADDLE_ENFORCE_NOT_NULL(
+        iter->second.front()->Var(),
+        platform::errors::InvalidArgument("The gradient var(%s) node is null.",
+                                          grad_var_name));
     PADDLE_ENFORCE_EQ(
         IsLoDTensorType(iter->second.front()->Var()->GetType()), true,
-        "Currently the gradient type only should be LoDTensor when "
-        "fusing optimizer ops.");
+        platform::errors::InvalidArgument(
+            "Currently the gradient(%s) type only should be LoDTensor when "
+            "fusing optimizer ops.",
+            grad_var_name));
     for (auto var : iter->second) {
       pinned_var_set.insert(var->Var()->Name());
     }
@@ -382,11 +398,14 @@ const VarDesc *FuseOptimizerOpPass::GetVarDescFromVarsInfo(
     const std::string &var_name) const {
   auto grad_iter = vars_info.find(var_name);
   PADDLE_ENFORCE_EQ(grad_iter != vars_info.end(), true,
-                    "The gradient variable %s is not found.", var_name);
+                    platform::errors::NotFound(
+                        "The gradient variable %s is not found.", var_name));
   PADDLE_ENFORCE_EQ(!grad_iter->second.empty(), true,
-                    "The gradient var node %s is not found.", var_name);
+                    platform::errors::NotFound(
+                        "The gradient var node %s is not found.", var_name));
   PADDLE_ENFORCE_NOT_NULL(grad_iter->second.front()->Var(),
-                          "The gradient var node is null.");
+                          platform::errors::InvalidArgument(
+                              "The gradient var(%s) node is null.", var_name));
   return grad_iter->second.front()->Var();
 }
@@ -428,8 +447,9 @@ void FuseOptimizerOpPass::SortParametersAndAuxVars(
     const std::vector<std::pair<std::string, std::string>> &params_grads,
     std::unordered_map<std::string, std::vector<std::string>> *aux_var_map,
     std::vector<ir::Node *> *ops) const {
-  PADDLE_ENFORCE_NE(aux_var_map->count(kGrad), static_cast<size_t>(0),
-                    "The gradient variable doesn‘t exist.");
+  PADDLE_ENFORCE_NE(
+      aux_var_map->count(kGrad), static_cast<size_t>(0),
+      platform::errors::NotFound("The gradient variable doesn‘t exist."));
   auto &grad_vec = aux_var_map->at(kGrad);
   std::vector<size_t> grad_sort_idx;
@@ -437,8 +457,10 @@ void FuseOptimizerOpPass::SortParametersAndAuxVars(
   for (auto &p_g : params_grads) {
     auto iter = std::find(grad_vec.begin(), grad_vec.end(), p_g.second);
-    PADDLE_ENFORCE_EQ(iter != grad_vec.end(), true,
-                      "%s is not found in gradient vector", p_g.second);
+    PADDLE_ENFORCE_EQ(
+        iter != grad_vec.end(), true,
+        platform::errors::NotFound(
+            "Parameter@Grad(%s) is not found in gradient vector.", p_g.second));
     auto idx = std::distance(grad_vec.begin(), iter);
     grad_sort_idx.emplace_back(idx);
   }
@@ -477,9 +499,10 @@ void FuseOptimizerOpPass::GetFusingVarNamesMap(
     for (auto &var_n : aux_vars_name) {
       auto arg_names = node->Op()->Input(var_n);
       PADDLE_ENFORCE_EQ(arg_names.size(), static_cast<size_t>(1),
-                        "The input variable of optimizer to be fused is "
-                        "invalid. Excepted %s only has one %s input.",
-                        node->Op()->Type(), var_n);
+                        platform::errors::InvalidArgument(
+                            "The input variable of optimizer to be fused is "
+                            "invalid. Excepted %s only has one %s input.",
+                            node->Op()->Type(), var_n));
       (*aux_args_name)[var_n].emplace_back(arg_names[0]);
     }
   }
@@ -525,10 +548,14 @@ void FuseOptimizerOpPass::InsertInputAndOutputForFusedOpNode(
   auto deal_with_ctrl_vars = [&out_dep_vars, &not_useful_vars,
                               &fused_opt_node](ir::Node *ctr_var_node) {
     PADDLE_ENFORCE_EQ(ctr_var_node->inputs.size(), 1,
-                      "The control var node has nultiple inputs.");
+                      platform::errors::InvalidArgument(
+                          "The control var(%s) node has multiple inputs.",
+                          ctr_var_node->Name()));
    if (ctr_var_node->inputs.front() == fused_opt_node) {
-      PADDLE_ENFORCE_GT(ctr_var_node->outputs.size(), 0,
-                        "The control var node has no output.");
+      PADDLE_ENFORCE_GT(
+          ctr_var_node->outputs.size(), 0,
+          platform::errors::InvalidArgument(
+              "The control var(%s) node has no output.", ctr_var_node->Name()));
       auto output_ops = ctr_var_node->outputs;
       output_ops.erase(std::remove_if(output_ops.begin(), output_ops.end(),
                                       [&fused_opt_node](const ir::Node *node) {
...
@@ -35,7 +35,9 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
       const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
       const std::unordered_map<std::string, std::string> &fused_vars_name,
       const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const {
-    PADDLE_ENFORCE_GT(sgd_ops.size(), static_cast<size_t>(0));
+    PADDLE_ENFORCE_GT(
+        sgd_ops.size(), static_cast<size_t>(0),
+        platform::errors::InvalidArgument("SGD ops must not be empyt."));
     // NOTE: fused_var is only exist in scope, so the graph doesn't have
     // fused_var node.
...
@@ -140,7 +140,7 @@ void GraphPatternDetector::ValidateByNodeRole(
       subgraphs->begin(), subgraphs->end(),
       [](const GraphPatternDetector::subgraph_t &subgraph) -> bool {
         // Collect the inputs and outputs.
-        std::unordered_set<Node *> ios;
+        std::set<Node *> ios;
         for (auto &item : subgraph) {
           if (!item.first->IsIntermediate()) {
             ios.insert(item.second);
@@ -166,7 +166,7 @@ void GraphPatternDetector::ValidateByNodeRole(
 }
 struct HitGroup {
-  std::unordered_map<PDNode *, Node *> roles;
+  std::map<PDNode *, Node *> roles;
   bool Match(Node *node, PDNode *pat) {
     if (nodes_.count(node)) {
@@ -184,7 +184,7 @@ struct HitGroup {
   }
  private:
-  std::unordered_set<Node *> nodes_;
+  std::set<Node *> nodes_;
 };
 // Tell whether Node a links to b.
@@ -283,7 +283,7 @@ void GraphPatternDetector::UniquePatterns(
   if (subgraphs->empty()) return;
   std::vector<GraphPatternDetector::subgraph_t> result;
-  std::unordered_set<size_t> set;
+  std::set<size_t> set;
   std::hash<std::string> hasher;
   for (auto &g : *subgraphs) {
     // Sort the items in the sub-graph, and transform to a string key.
@@ -305,7 +305,7 @@ void GraphPatternDetector::UniquePatterns(
 void GraphPatternDetector::RemoveOverlappedMatch(
     std::vector<subgraph_t> *subgraphs) {
   std::vector<subgraph_t> result;
-  std::unordered_set<Node *> node_set;
+  std::set<Node *> node_set;
   for (const auto &subgraph : *subgraphs) {
     bool valid = true;
...
@@ -231,7 +231,7 @@ class PDPattern {
   std::vector<std::unique_ptr<PDNode>> nodes_;
   std::vector<edge_t> edges_;
-  std::unordered_map<std::string, PDNode*> node_map_;
+  std::map<std::string, PDNode*> node_map_;
   static size_t id_;
 };
@@ -263,7 +263,7 @@ class PDPattern {
  */
 class GraphPatternDetector {
  public:
-  using subgraph_t = std::unordered_map<PDNode*, Node*>;
+  using subgraph_t = std::map<PDNode*, Node*>;
   // Operate on the detected pattern.
   using handle_t =
...
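The `graph_pattern_detector` hunks are the one change in this commit that is not about error messages: `std::unordered_set` and `std::unordered_map` become their ordered counterparts. The commit itself gives no rationale; a plausible one (an assumption, not stated in the diff) is reproducibility, since unordered containers iterate in a hash- and implementation-dependent order, while `std::map` and `std::set` always iterate in key order. A small illustration of the difference:

#include <iostream>
#include <map>
#include <string>
#include <unordered_map>

// Both maps hold the same pairs, but only the ordered map guarantees
// the iteration order (sorted by key) across standard libraries.
int main() {
  std::unordered_map<std::string, int> unordered{{"w", 1}, {"x", 2}, {"y", 3}};
  std::map<std::string, int> ordered(unordered.begin(), unordered.end());

  // Order here depends on the hash function and bucket layout:
  for (const auto& kv : unordered) std::cout << kv.first << ' ';
  std::cout << '\n';

  // Order here is always "w x y":
  for (const auto& kv : ordered) std::cout << kv.first << ' ';
  std::cout << '\n';
  return 0;
}

For pointer-keyed containers like `subgraph_t`, ordering by key makes pattern matching deterministic within a run regardless of hash seed, at the usual cost of O(log n) lookups instead of amortized O(1).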
@@ -116,7 +116,10 @@ std::vector<OpHandleBase *> BufferSharedCrossOpMemoryReusePass::SortOp(
   graph_view.BreadthFirstVisit(
       [&](OpHandleBase *cur_op) { sorted_ops.emplace_back(cur_op); });
   PADDLE_ENFORCE_EQ(sorted_ops.size(), graph_view.OpNumber(),
-                    "There are unvisited ops");
+                    platform::errors::InvalidArgument(
+                        "Sorted ops size(%d) not equal to graph op size(%d). "
+                        "There are unvisited ops.",
+                        sorted_ops.size(), graph_view.OpNumber()));
   return sorted_ops;
 }
@@ -181,7 +184,9 @@ void BufferSharedCrossOpMemoryReusePass::RunOnScopeIdx(size_t idx) const {
     auto *out_node = *(out_nodes.begin());
     auto *out_var =
         dynamic_cast<VarHandle *>(&(out_node->Wrapper<VarHandleBase>()));
-    PADDLE_ENFORCE_NOT_NULL(out_var);
+    PADDLE_ENFORCE_NOT_NULL(
+        out_var, platform::errors::NotFound(
+                     "Can not find a valid Var Node for Var %s.", out_arg));
     // If out_arg is not reusable, skip it
     if (!IsOutVarReusable(*out_var)) {
@@ -269,7 +274,8 @@ size_t BufferSharedCrossOpMemoryReusePass::ResolveDependencyBetween(
       auto op_dep = GetOpDep(prev_op, op);
       if (op_dep == NodeDependency::kBefore) continue;
       PADDLE_ENFORCE_EQ(op_dep, NodeDependency::kNoDep,
-                        "The graph has circle, this may be a bug");
+                        platform::errors::InvalidArgument(
+                            "The graph has circle, this may be a bug."));
       auto iter =
           std::find_if(prev_op->Outputs().begin(), prev_op->Outputs().end(),
@@ -316,9 +322,13 @@ size_t BufferSharedCrossOpMemoryReusePass::ResolveDependencyBetween(
 }
 void BufferSharedCrossOpMemoryReusePass::BuildOpDependencyMap() const {
-  PADDLE_ENFORCE(ops_.empty(), "ops_ must be initialized here");
-  PADDLE_ENFORCE(op_to_idx_.empty(), "op_to_idx_ must be initialized here");
-  PADDLE_ENFORCE(deps_.empty(), "deps_ must be initialized here");
+  PADDLE_ENFORCE_EQ(ops_.empty(), true, platform::errors::InvalidArgument(
+                                            "Ops must be initialized here."));
+  PADDLE_ENFORCE_EQ(
+      op_to_idx_.empty(), true,
+      platform::errors::InvalidArgument("Op to idx must be initialized here."));
+  PADDLE_ENFORCE_EQ(deps_.empty(), true, platform::errors::InvalidArgument(
+                                             "Deps must be initialized here."));
   // Toposort ops
   OpGraphView graph_view(ir::FilterByNodeWrapper<OpHandleBase>(*graph_));
@@ -344,7 +354,10 @@ void BufferSharedCrossOpMemoryReusePass::BuildOpDependencyMap() const {
                                 prev_preceding_ops.end());
       }
     });
-    PADDLE_ENFORCE_EQ(preceding_ops.size(), op_num);
+    PADDLE_ENFORCE_EQ(preceding_ops.size(), op_num,
+                      platform::errors::InvalidArgument(
+                          "Preceding ops size(%d) must equal to op num(%d).",
+                          preceding_ops.size(), op_num));
   // Find out ComputationOpHandles only
   ops_.resize(scope_num);
@@ -384,28 +397,43 @@ void BufferSharedCrossOpMemoryReusePass::BuildOpDependencyMap() const {
 size_t BufferSharedCrossOpMemoryReusePass::OpIndex(
     const ComputationOpHandle *op) const {
   auto iter = op_to_idx_[op->GetScopeIdx()].find(op);
-  PADDLE_ENFORCE(iter != op_to_idx_[op->GetScopeIdx()].end());
+  PADDLE_ENFORCE_NE(iter, op_to_idx_[op->GetScopeIdx()].end(),
+                    platform::errors::NotFound(
+                        "Can not find op(%s) in op_to_idx_.", op->Name()));
   return iter->second;
 }
 NodeDependency BufferSharedCrossOpMemoryReusePass::GetOpDep(
     const ComputationOpHandle *op1, const ComputationOpHandle *op2) const {
-  PADDLE_ENFORCE_EQ(op1->GetScopeIdx(), op2->GetScopeIdx());
+  PADDLE_ENFORCE_EQ(op1->GetScopeIdx(), op2->GetScopeIdx(),
+                    platform::errors::InvalidArgument(
+                        "Op(%s) and op(%s) must in the same scope.",
+                        op1->Name(), op2->Name()));
   return deps_[op1->GetScopeIdx()][OpIndex(op1)][OpIndex(op2)];
 }
 void BufferSharedCrossOpMemoryReusePass::SetOpDep(
     const ComputationOpHandle *op1, const ComputationOpHandle *op2,
     NodeDependency dep) const {
-  PADDLE_ENFORCE_EQ(op1->GetScopeIdx(), op2->GetScopeIdx());
+  PADDLE_ENFORCE_EQ(op1->GetScopeIdx(), op2->GetScopeIdx(),
+                    platform::errors::InvalidArgument(
+                        "Op(%s) and op(%s) must in the same scope.",
+                        op1->Name(), op2->Name()));
   if (op1 == op2) {
-    PADDLE_ENFORCE(dep == NodeDependency::kSame);
+    PADDLE_ENFORCE_EQ(
+        dep, NodeDependency::kSame,
+        platform::errors::InvalidArgument(
+            "Set Same Op(%s) Dep, dep must be kSame type.", op1->Name()));
     auto idx = OpIndex(op1);
     deps_[op1->GetScopeIdx()][idx][idx] = NodeDependency::kSame;
   } else {
     auto idx1 = OpIndex(op1);
     auto idx2 = OpIndex(op2);
-    PADDLE_ENFORCE(dep != NodeDependency::kSame && idx1 != idx2);
+    PADDLE_ENFORCE_EQ((dep != NodeDependency::kSame && idx1 != idx2), true,
+                      platform::errors::InvalidArgument(
+                          "Op(%s) and Op(%s) should not have same "
+                          "index(%d), and dep should not kSame type.",
+                          op1->Name(), op2->Name(), idx1));
     deps_[op1->GetScopeIdx()][idx1][idx2] = dep;
     deps_[op1->GetScopeIdx()][idx2][idx1] = ReverseNodeDependency(dep);
   }
...
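`GetOpDep` and `SetOpDep` above read and write a per-scope pairwise matrix `deps_[scope][i][j]`, and every write mirrors the reversed dependency into `[j][i]`, so a query never has to probe both orders. A minimal sketch of that bookkeeping with hypothetical simplified types:

#include <cassert>
#include <vector>

// Hypothetical simplified dependency kinds mirroring NodeDependency.
enum class Dep { kNoDep, kBefore, kAfter, kSame };

Dep Reverse(Dep d) {
  if (d == Dep::kBefore) return Dep::kAfter;
  if (d == Dep::kAfter) return Dep::kBefore;
  return d;  // kNoDep and kSame are their own reverse
}

int main() {
  const size_t op_num = 3;
  // Pairwise dependency matrix, as in deps_[scope_idx][i][j].
  std::vector<std::vector<Dep>> deps(op_num,
                                     std::vector<Dep>(op_num, Dep::kNoDep));

  // Recording "op 0 runs before op 2" also records the mirrored entry.
  deps[0][2] = Dep::kBefore;
  deps[2][0] = Reverse(Dep::kBefore);

  assert(deps[2][0] == Dep::kAfter);  // single O(1) lookup in either direction
  return 0;
}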
@@ -57,7 +57,9 @@ void BufferSharedInplaceOpPass::Run(Graph *graph) const {
     auto *op = *(pair.second.ops().begin());
     const std::string &op_type = op->GetOp()->Type();
     const framework::OpDesc *op_desc = op->Node()->Op();
-    PADDLE_ENFORCE_NOT_NULL(op_desc);
+    PADDLE_ENFORCE_NOT_NULL(
+        op_desc, platform::errors::NotFound("Op(%s) can not find opdesc.",
+                                            op->Name()));
     auto &infer_inplace = OpInfoMap::Instance().Get(op_type).infer_inplace_;
     if (!infer_inplace) {
...
...@@ -58,8 +58,12 @@ static int64_t GetMemorySize( ...@@ -58,8 +58,12 @@ static int64_t GetMemorySize(
&vars, &vars,
const std::string &var_name) { const std::string &var_name) {
auto *var_desc = TryGetLatestVarDesc(vars.at(var_name)); auto *var_desc = TryGetLatestVarDesc(vars.at(var_name));
PADDLE_ENFORCE_NOT_NULL(var_desc); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE(IsLoDTensor(var_desc)); var_desc,
platform::errors::NotFound("Var(%s) can not find VarDesc.", var_name));
PADDLE_ENFORCE_EQ(IsLoDTensor(var_desc), true,
platform::errors::InvalidArgument(
"Var(%s) must be LoDTensor.", var_name));
auto dims = var_desc->GetShape(); auto dims = var_desc->GetShape();
return SizeOfType(var_desc->GetDataType()) * return SizeOfType(var_desc->GetDataType()) *
std::accumulate(dims.begin(), dims.end(), static_cast<int64_t>(1), std::accumulate(dims.begin(), dims.end(), static_cast<int64_t>(1),
......
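For reference, the size computed in GetMemorySize is just the element size times the product of the dimensions returned by GetShape. A hypothetical, self-contained restatement (MemorySizeBytes is an illustrative name, not a Paddle function):

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int64_t MemorySizeBytes(const std::vector<int64_t>& dims,
                            int64_t elem_size) {
      // bytes = sizeof(element type) * product of all dimensions
      return elem_size *
             std::accumulate(dims.begin(), dims.end(),
                             static_cast<int64_t>(1),
                             std::multiplies<int64_t>());
    }

    int main() {
      // e.g. a float32 tensor of shape [32, 128, 768] -> 12582912 bytes
      std::cout << MemorySizeBytes({32, 128, 768}, 4) << "\n";
    }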
...@@ -42,8 +42,10 @@ class MemOptVarInfo { ...@@ -42,8 +42,10 @@ class MemOptVarInfo {
} }
void SetRefCnt(size_t ref_cnt) { void SetRefCnt(size_t ref_cnt) {
PADDLE_ENFORCE_GE(ref_cnt, 1, PADDLE_ENFORCE_GE(
"Reference count must be larger than or equal to 1"); ref_cnt, 1,
platform::errors::InvalidArgument(
"Reference count(%d) must be larger than or equal to 1.", ref_cnt));
ref_cnt_ = ref_cnt; ref_cnt_ = ref_cnt;
runtime_ref_cnt_ = ref_cnt; runtime_ref_cnt_ = ref_cnt;
} }
......
...@@ -66,7 +66,11 @@ bool MemoryReusePass::TryReuseVar(details::VarHandle *in_var, ...@@ -66,7 +66,11 @@ bool MemoryReusePass::TryReuseVar(details::VarHandle *in_var,
details::VarHandle *out_var) const { details::VarHandle *out_var) const {
auto *op = auto *op =
dynamic_cast<details::ComputationOpHandle *>(out_var->GeneratedOp()); dynamic_cast<details::ComputationOpHandle *>(out_var->GeneratedOp());
PADDLE_ENFORCE_NOT_NULL(op); PADDLE_ENFORCE_NOT_NULL(
op,
platform::errors::InvalidArgument(
"Var(%s) have no GeneratedOp, or it's op is not ComputationOpHandle.",
out_var->Name()));
if (IsVarPairReusable(*in_var, *out_var)) { if (IsVarPairReusable(*in_var, *out_var)) {
AddReuseVar(op, in_var, out_var); AddReuseVar(op, in_var, out_var);
return true; return true;
...@@ -91,10 +95,13 @@ VarDesc *MemoryReusePass::GetVarDesc(const details::VarHandle &var) const { ...@@ -91,10 +95,13 @@ VarDesc *MemoryReusePass::GetVarDesc(const details::VarHandle &var) const {
size_t scope_idx = var.scope_idx(); size_t scope_idx = var.scope_idx();
auto iter = var_descs_[scope_idx].find(var_name); auto iter = var_descs_[scope_idx].find(var_name);
if (iter == var_descs_[scope_idx].end()) { if (iter == var_descs_[scope_idx].end()) {
PADDLE_ENFORCE((*all_vars_)[scope_idx].count(var_name), PADDLE_ENFORCE_NE(
"Variable %s not found", var_name); (*all_vars_)[scope_idx].count(var_name), 0,
platform::errors::NotFound("Variable %s not found.", var_name));
auto *desc = TryGetLatestVarDesc((*all_vars_)[scope_idx].at(var_name)); auto *desc = TryGetLatestVarDesc((*all_vars_)[scope_idx].at(var_name));
PADDLE_ENFORCE_NOT_NULL(desc); PADDLE_ENFORCE_NOT_NULL(
desc,
platform::errors::NotFound("Var(%s) can not find VarDesc.", var_name));
var_descs_[scope_idx].emplace(var_name, desc); var_descs_[scope_idx].emplace(var_name, desc);
return desc; return desc;
} else { } else {
...@@ -119,7 +126,9 @@ void MemoryReusePass::CollectShareTensorBufferOpHandles() const { ...@@ -119,7 +126,9 @@ void MemoryReusePass::CollectShareTensorBufferOpHandles() const {
if (share_buffer_op != nullptr) { if (share_buffer_op != nullptr) {
auto *compute_op = auto *compute_op =
details::GetUniquePendingComputationOpHandle(share_buffer_op); details::GetUniquePendingComputationOpHandle(share_buffer_op);
PADDLE_ENFORCE(ops_.count(compute_op) == 0); PADDLE_ENFORCE_EQ(
ops_.count(compute_op), 0,
platform::errors::AlreadyExists("Compute op already exists."));
ops_.emplace(compute_op, share_buffer_op); ops_.emplace(compute_op, share_buffer_op);
} }
} }
...@@ -227,8 +236,11 @@ bool MemoryReusePass::IsInVarReusable(const details::VarHandle &in_var) const { ...@@ -227,8 +236,11 @@ bool MemoryReusePass::IsInVarReusable(const details::VarHandle &in_var) const {
*/ */
bool MemoryReusePass::IsOutVarReusable( bool MemoryReusePass::IsOutVarReusable(
const details::VarHandle &out_var) const { const details::VarHandle &out_var) const {
PADDLE_ENFORCE_NOT_NULL(dynamic_cast<const details::ComputationOpHandle *>( PADDLE_ENFORCE_NOT_NULL(
out_var.GeneratedOp())); dynamic_cast<const details::ComputationOpHandle *>(out_var.GeneratedOp()),
platform::errors::InvalidArgument(
"Var(%s) have no GeneratedOp, or it's op is not ComputationOpHandle.",
out_var.Name()));
const auto out_name = out_var.Name(); const auto out_name = out_var.Name();
if (out_name == kEmptyVarName) { if (out_name == kEmptyVarName) {
return false; return false;
...@@ -236,9 +248,10 @@ bool MemoryReusePass::IsOutVarReusable( ...@@ -236,9 +248,10 @@ bool MemoryReusePass::IsOutVarReusable(
// out_var must be the first version!!! // out_var must be the first version!!!
auto out_var_iter = (*all_vars_)[out_var.scope_idx()].find(out_name); auto out_var_iter = (*all_vars_)[out_var.scope_idx()].find(out_name);
PADDLE_ENFORCE(out_var_iter != (*all_vars_)[out_var.scope_idx()].end() && PADDLE_ENFORCE_EQ(
!out_var_iter->second.empty(), (out_var_iter != (*all_vars_)[out_var.scope_idx()].end() &&
"Cannot find variable %s", out_name); !out_var_iter->second.empty()),
true, platform::errors::NotFound("Cannot find variable %s.", out_name));
if (out_var_iter->second[0] != &out_var) { if (out_var_iter->second[0] != &out_var) {
return false; return false;
...@@ -282,7 +295,11 @@ bool MemoryReusePass::IsVarPairReusable( ...@@ -282,7 +295,11 @@ bool MemoryReusePass::IsVarPairReusable(
const details::VarHandle &in_var, const details::VarHandle &out_var) const { const details::VarHandle &in_var, const details::VarHandle &out_var) const {
auto *op = auto *op =
dynamic_cast<const details::ComputationOpHandle *>(out_var.GeneratedOp()); dynamic_cast<const details::ComputationOpHandle *>(out_var.GeneratedOp());
PADDLE_ENFORCE_NOT_NULL(op); PADDLE_ENFORCE_NOT_NULL(
op,
platform::errors::InvalidArgument(
"Var(%s) have no GeneratedOp, or it's op is not ComputationOpHandle.",
out_var.Name()));
const auto in_name = in_var.Name(); const auto in_name = in_var.Name();
if (in_name == out_var.Name()) { if (in_name == out_var.Name()) {
...@@ -308,8 +325,10 @@ bool MemoryReusePass::IsVarPairReusable( ...@@ -308,8 +325,10 @@ bool MemoryReusePass::IsVarPairReusable(
void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op, void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op,
details::VarHandle *in_var, details::VarHandle *in_var,
details::VarHandle *out_var) const { details::VarHandle *out_var) const {
PADDLE_ENFORCE((*var_infos_)[op->GetScopeIdx()].count(in_var->Name()) > 0, PADDLE_ENFORCE_GT(
"%s does not in mem-opt var infos", in_var->Name()); (*var_infos_)[op->GetScopeIdx()].count(in_var->Name()), 0,
platform::errors::NotFound("Var(%s) does not in mem opt var infos.",
in_var->Name()));
if (ops_.count(op) == 0) { if (ops_.count(op) == 0) {
InsertShareTensorBufferOpHandleToGraph(op); InsertShareTensorBufferOpHandleToGraph(op);
...@@ -349,7 +368,10 @@ void MemoryReusePass::UpdateLastLiveOpOfVar(details::ComputationOpHandle *op, ...@@ -349,7 +368,10 @@ void MemoryReusePass::UpdateLastLiveOpOfVar(details::ComputationOpHandle *op,
if (out_var_op_iter == (*last_live_ops_of_vars_)[scope_idx].end()) { if (out_var_op_iter == (*last_live_ops_of_vars_)[scope_idx].end()) {
last_live_op_of_in_var = op; last_live_op_of_in_var = op;
} else { } else {
PADDLE_ENFORCE(!out_var_op_iter->second.ops().empty()); PADDLE_ENFORCE_EQ(
out_var_op_iter->second.ops().empty(), false,
platform::errors::InvalidArgument(
"Var(%s)'s last live op should not empty.", out_var->Name()));
last_live_op_of_in_var = *(out_var_op_iter->second.ops().begin()); last_live_op_of_in_var = *(out_var_op_iter->second.ops().begin());
} }
...@@ -359,8 +381,9 @@ void MemoryReusePass::UpdateLastLiveOpOfVar(details::ComputationOpHandle *op, ...@@ -359,8 +381,9 @@ void MemoryReusePass::UpdateLastLiveOpOfVar(details::ComputationOpHandle *op,
last_live_ops_of_in_var->insert(last_live_op_of_in_var); last_live_ops_of_in_var->insert(last_live_op_of_in_var);
auto in_var_info_iter = (*var_infos_)[scope_idx].find(in_var->Name()); auto in_var_info_iter = (*var_infos_)[scope_idx].find(in_var->Name());
PADDLE_ENFORCE(in_var_info_iter != (*var_infos_)[scope_idx].end(), PADDLE_ENFORCE_NE(
"Cannot find variable %s", in_var->Name()); in_var_info_iter, (*var_infos_)[scope_idx].end(),
platform::errors::NotFound("Cannot find variable %s.", in_var->Name()));
in_var_info_iter->second->SetRefCnt(1); in_var_info_iter->second->SetRefCnt(1);
} }
......
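The repeated PADDLE_ENFORCE_NOT_NULL(dynamic_cast<...>) pattern in this file guards one precondition: the op that generated a variable must be a ComputationOpHandle. A minimal sketch of that guard under assumed, simplified types (OpHandleBase and ComputationOpHandle here are stand-ins, not the Paddle classes):

    #include <stdexcept>

    struct OpHandleBase { virtual ~OpHandleBase() = default; };
    struct ComputationOpHandle : OpHandleBase {};

    // Promote a null dynamic_cast result into a typed error instead of
    // letting a null pointer propagate and crash later.
    ComputationOpHandle* AsComputationOp(OpHandleBase* base) {
      auto* op = dynamic_cast<ComputationOpHandle*>(base);
      if (op == nullptr)
        throw std::invalid_argument(
            "Var has no GeneratedOp, or its op is not ComputationOpHandle.");
      return op;
    }

    int main() {
      ComputationOpHandle real;
      AsComputationOp(&real);  // ok; a non-computation op would throw
    }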
...@@ -39,7 +39,7 @@ void OpGraphView::Build(const std::vector<details::OpHandleBase *> &ops) { ...@@ -39,7 +39,7 @@ void OpGraphView::Build(const std::vector<details::OpHandleBase *> &ops) {
} }
PADDLE_ENFORCE( PADDLE_ENFORCE(
preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(), preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
"There are duplicate ops in graph."); platform::errors::InvalidArgument("There are duplicate ops in graph."));
} }
std::unordered_set<details::OpHandleBase *> OpGraphView::AllOps() const { std::unordered_set<details::OpHandleBase *> OpGraphView::AllOps() const {
...@@ -56,8 +56,10 @@ bool OpGraphView::HasOp(details::OpHandleBase *op) const { ...@@ -56,8 +56,10 @@ bool OpGraphView::HasOp(details::OpHandleBase *op) const {
} }
void OpGraphView::EnforceHasOp(details::OpHandleBase *op) const { void OpGraphView::EnforceHasOp(details::OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView", PADDLE_ENFORCE_EQ(HasOp(op), true,
op == nullptr ? "nullptr" : op->DebugString()); platform::errors::NotFound(
"Cannot find op %s in OpGraphView.",
op == nullptr ? "nullptr" : op->DebugString()));
} }
const std::unordered_set<details::OpHandleBase *> &OpGraphView::PendingOps( const std::unordered_set<details::OpHandleBase *> &OpGraphView::PendingOps(
......
...@@ -127,9 +127,13 @@ void OpGraphView::BreadthFirstVisit(Callback &&callback) const { ...@@ -127,9 +127,13 @@ void OpGraphView::BreadthFirstVisit(Callback &&callback) const {
} }
} }
PADDLE_ENFORCE_EQ(num_calls, op_num, "There are unvisited ops"); PADDLE_ENFORCE_EQ(num_calls, op_num, platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(visited_ops.size(), op_num, "There are unvisited ops"); "There are unvisited ops."));
PADDLE_ENFORCE(op_deps.empty(), "There are unvisited ops"); PADDLE_ENFORCE_EQ(
visited_ops.size(), op_num,
platform::errors::InvalidArgument("There are unvisited ops."));
PADDLE_ENFORCE_EQ(op_deps.empty(), true, platform::errors::InvalidArgument(
"There are unvisited ops."));
} }
} // namespace ir } // namespace ir
......
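The three checks at the end of BreadthFirstVisit guard one invariant: every op is visited exactly once, so num_calls, the visited set, and the drained dependency map must all agree with op_num. A self-contained sketch of the same invariant over an assumed toy dependency graph:

    #include <cassert>
    #include <queue>
    #include <vector>

    int main() {
      // pending[u] lists ops that depend on u; in_deg counts dependencies.
      std::vector<std::vector<int>> pending = {{1, 2}, {3}, {3}, {}};
      std::vector<int> in_deg = {0, 1, 1, 2};
      std::queue<int> ready;
      for (int i = 0; i < 4; ++i)
        if (in_deg[i] == 0) ready.push(i);
      int num_calls = 0;
      while (!ready.empty()) {
        int u = ready.front();
        ready.pop();
        ++num_calls;  // the "callback" invocation
        for (int v : pending[u])
          if (--in_deg[v] == 0) ready.push(v);
      }
      // A cycle or disconnected dependency would leave num_calls < 4.
      assert(num_calls == 4 && "There are unvisited ops");
    }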
...@@ -77,11 +77,15 @@ class ShrinkDepsOpFunctor { ...@@ -77,11 +77,15 @@ class ShrinkDepsOpFunctor {
const std::vector<details::OpHandleBase *> &ops) const { const std::vector<details::OpHandleBase *> &ops) const {
std::unordered_map<details::OpHandleBase *, size_t> op_to_idx; std::unordered_map<details::OpHandleBase *, size_t> op_to_idx;
for (size_t i = 0; i < ops.size(); ++i) { for (size_t i = 0; i < ops.size(); ++i) {
PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph"); PADDLE_ENFORCE_EQ(
graph_.HasOp(ops[i]), true,
platform::errors::InvalidArgument("Op does not exist in graph."));
op_to_idx[ops[i]] = i; op_to_idx[ops[i]] = i;
} }
PADDLE_ENFORCE(op_to_idx.size() == ops.size(), "Duplicate ops"); PADDLE_ENFORCE_EQ(
op_to_idx.size(), ops.size(),
platform::errors::InvalidArgument("Graph may have duplicate ops."));
std::vector<std::vector<RelationShip>> ret(ops.size()); std::vector<std::vector<RelationShip>> ret(ops.size());
for (auto &e : ret) { for (auto &e : ret) {
...@@ -247,9 +251,9 @@ ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx, ...@@ -247,9 +251,9 @@ ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx,
return {}; return {};
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(computation_ops.empty(), false,
computation_ops.empty(), false, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("Computation ops should not be empty")); "Computation ops should not be empty."));
// stage four. Try to shrink computation op if they depend on each other. // stage four. Try to shrink computation op if they depend on each other.
// Get the smallest set of the most ops. // Get the smallest set of the most ops.
...@@ -263,8 +267,9 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ...@@ -263,8 +267,9 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars); Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
PADDLE_ENFORCE(last_live_ops_of_vars.empty() && var_infos.empty(), PADDLE_ENFORCE(last_live_ops_of_vars.empty() && var_infos.empty(),
"Last Live Ops and Reference Counts of vars should be " platform::errors::InvalidArgument(
"initialized at here."); "Last live ops and reference counts of vars should be "
"initialized at here."));
const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars); const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
...@@ -304,11 +309,15 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ...@@ -304,11 +309,15 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &var_name = name_var_pair.first; auto &var_name = name_var_pair.first;
auto &var_handles = name_var_pair.second; auto &var_handles = name_var_pair.second;
PADDLE_ENFORCE_EQ(var_desc->Name(), var_name);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
var_handles.empty(), false, var_desc->Name(), var_name,
platform::errors::InvalidArgument("Variable %s not found", var_name)); platform::errors::InvalidArgument(
"A Var, it's VarName(%s) and DescName(%s) not same.", var_name,
var_desc->Name()));
PADDLE_ENFORCE_EQ(var_handles.empty(), false,
platform::errors::InvalidArgument(
"Variable %s not found.", var_name));
auto last_ver_var = var_handles.back(); auto last_ver_var = var_handles.back();
if (last_ver_var->Node()->IsCtrlVar()) { if (last_ver_var->Node()->IsCtrlVar()) {
...@@ -327,12 +336,13 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ...@@ -327,12 +336,13 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
continue; continue;
} }
PADDLE_ENFORCE_EQ(status, LastLiveOpSearchStatus::kSuccess,
platform::errors::InvalidArgument(
"Status(%d) must be success.", status));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
status, LastLiveOpSearchStatus::kSuccess, result.empty(), false,
platform::errors::InvalidArgument("status must be success")); platform::errors::NotFound("Last living ops of %s cannot be empty.",
PADDLE_ENFORCE_EQ(result.empty(), false, var_name));
platform::errors::NotFound(
"Last living ops of %s cannot be empty", var_name));
std::string last_live_ops_log_str; std::string last_live_ops_log_str;
for (auto &each_ret : result) { for (auto &each_ret : result) {
......
...@@ -22,7 +22,8 @@ namespace framework { ...@@ -22,7 +22,8 @@ namespace framework {
namespace ir { namespace ir {
void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const { void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph, "graph cannot be nullptr."); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("conv_activation_mkldnn_fuse", graph); FusePassBase::Init("conv_activation_mkldnn_fuse", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -75,7 +76,8 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -75,7 +76,8 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
GraphSafeRemoveNodes(graph, {activation, conv_out}); GraphSafeRemoveNodes(graph, {activation, conv_out});
PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL, PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL,
"subgraph has to contain conv_input node."); platform::errors::InvalidArgument(
"Subgraph has to contain conv input node."));
IR_NODE_LINK_TO(conv, activation_out); IR_NODE_LINK_TO(conv, activation_out);
found_conv_activation_count++; found_conv_activation_count++;
}; };
......
...@@ -26,7 +26,11 @@ namespace ir { ...@@ -26,7 +26,11 @@ namespace ir {
template <typename BinaryOperation> template <typename BinaryOperation>
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b, LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
BinaryOperation f) { BinaryOperation f) {
PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims()); PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims(),
platform::errors::InvalidArgument(
"Input two tensors must have same shape, but they are "
"different: %s, %s.",
vec_a.dims(), vec_b.dims()));
LoDTensor vec_y; LoDTensor vec_y;
vec_y.Resize(vec_a.dims()); vec_y.Resize(vec_a.dims());
const float* a = vec_a.data<float>(); const float* a = vec_a.data<float>();
...@@ -39,11 +43,13 @@ LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b, ...@@ -39,11 +43,13 @@ LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
} }
void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope(); auto* scope = param_scope();
PADDLE_ENFORCE(scope); PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* conv_input = auto* conv_input =
...@@ -68,7 +74,9 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -68,7 +74,9 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
// elementwise_add op // elementwise_add op
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);
PADDLE_ENFORCE(subgraph.count(conv_input)); PADDLE_ENFORCE_NE(
subgraph.count(conv_input), 0,
platform::errors::NotFound("Detector did not find conv input."));
// check if fuse can be done and if MKL-DNN should be used // check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *eltwise); FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
...@@ -86,10 +94,16 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -86,10 +94,16 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
if (has_bias && conv->Op()->Input("Bias").size() > 0) { if (has_bias && conv->Op()->Input("Bias").size() > 0) {
auto conv_bias_names = conv->Op()->Input("Bias"); auto conv_bias_names = conv->Op()->Input("Bias");
// add eltwise bias to existing conv bias // add eltwise bias to existing conv bias
PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1); PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1,
platform::errors::NotFound("Can not find var Bias."));
auto* conv_bias_var = scope->FindVar(conv_bias_names[0]); auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>(); auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims()); PADDLE_ENFORCE_EQ(
conv_bias_tensor->dims(), eltwise_bias_tensor->dims(),
platform::errors::InvalidArgument(
"Conv bias tensor and eltwise bias tensor "
"must have same shape, but they are different: %s, %s.",
conv_bias_tensor->dims(), eltwise_bias_tensor->dims()));
*conv_bias_tensor = tensor_apply_eltwise( *conv_bias_tensor = tensor_apply_eltwise(
*conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>()); *conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>());
......
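The fuse above folds the elementwise_add bias into the existing conv bias via tensor_apply_eltwise(..., std::plus<float>()), after enforcing equal shapes. A simplified stand-in over plain vectors (ApplyEltwise is an illustrative name; the real helper works on LoDTensor):

    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <vector>

    std::vector<float> ApplyEltwise(const std::vector<float>& a,
                                    const std::vector<float>& b,
                                    std::plus<float> f) {
      assert(a.size() == b.size());  // mirrors the dims() enforce above
      std::vector<float> y(a.size());
      for (std::size_t i = 0; i < a.size(); ++i) y[i] = f(a[i], b[i]);
      return y;
    }

    int main() {
      auto y = ApplyEltwise({1.f, 2.f}, {0.5f, 0.5f}, std::plus<float>());
      assert(y[0] == 1.5f && y[1] == 2.5f);
    }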
...@@ -39,7 +39,10 @@ void ConvConcatReLUFusePass::FindConcatWithConvs( ...@@ -39,7 +39,10 @@ void ConvConcatReLUFusePass::FindConcatWithConvs(
for (auto node : concat_inputs) { for (auto node : concat_inputs) {
auto prev_op_node = node->inputs; auto prev_op_node = node->inputs;
PADDLE_ENFORCE_EQ(prev_op_node.size(), 1); PADDLE_ENFORCE_EQ(prev_op_node.size(), 1,
platform::errors::InvalidArgument(
"Node(%s) input size(%d) must be 1.", node->Name(),
prev_op_node.size()));
auto* conv_op = prev_op_node[0]; auto* conv_op = prev_op_node[0];
if (conv_op->Op()->Type() != "conv2d") return; if (conv_op->Op()->Type() != "conv2d") return;
...@@ -103,7 +106,8 @@ void ConvConcatReLUFusePass::FuseConvConcatReLU( ...@@ -103,7 +106,8 @@ void ConvConcatReLUFusePass::FuseConvConcatReLU(
} }
void ConvConcatReLUFusePass::ApplyImpl(ir::Graph* graph) const { void ConvConcatReLUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
std::unordered_map<const Node*, int> concat_with_convs_counter; std::unordered_map<const Node*, int> concat_with_convs_counter;
......
...@@ -37,18 +37,24 @@ void UnlinkNodes(ir::Node* a, ir::Node* b) { ...@@ -37,18 +37,24 @@ void UnlinkNodes(ir::Node* a, ir::Node* b) {
b->inputs.end()); b->inputs.end());
} }
void LogCannotQuantizeOp(Node* op) { void LogCannotQuantizeOp(Node* op, const char* details = nullptr) {
std::stringstream msg_ss; std::stringstream msg_ss;
msg_ss << "Cannot quantize operator " << op->Name() msg_ss << "Cannot quantize operator " << op->Name()
<< " (type: " << op->Op()->Type() << ", id: " << op->id() << ")."; << " (type: " << op->Op()->Type() << ", id: " << op->id() << ").";
if (details) msg_ss << " " << details;
PrettyLogDetail(msg_ss.str().c_str()); PrettyLogDetail(msg_ss.str().c_str());
} }
void LogScaleIsMissingForVar(Node* var) { void LogScaleIsMissingForVar(Node* var) {
VLOG(4) << "Quantization scale for the variable " << var->Name()
<< " is missing.";
}
void LogQuantizationDisabled(Node* op) {
std::stringstream msg_ss; std::stringstream msg_ss;
msg_ss << "Quantization scale for the variable " << var->Name() VLOG(4) << "Qantization skipped for operator " << op->Name()
<< " is missing."; << " (type: " << op->Op()->Type() << ", id: " << op->id()
PrettyLogDetail(msg_ss.str().c_str()); << "). Attribute use_quantizer = false.";
} }
} // namespace } // namespace
...@@ -62,10 +68,10 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input, ...@@ -62,10 +68,10 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input,
auto inputs = op->Op()->InputNames(); auto inputs = op->Op()->InputNames();
bool name_found = bool name_found =
std::find(inputs.begin(), inputs.end(), input_name) != inputs.end(); std::find(inputs.begin(), inputs.end(), input_name) != inputs.end();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(name_found, true,
name_found, true, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("%s isn't the input of the %s operator", "Var(%s) isn't the input of the %s operator.",
input_name, op->Op()->Type())); input_name, op->Op()->Type()));
unsigned max = is_unsigned ? U8_MAX : S8_MAX; unsigned max = is_unsigned ? U8_MAX : S8_MAX;
float scale = scale_to_one * max; float scale = scale_to_one * max;
...@@ -104,8 +110,14 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name, ...@@ -104,8 +110,14 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name,
std::string scale_attr_name) const { std::string scale_attr_name) const {
auto inputs = op->inputs; auto inputs = op->inputs;
auto output = op->outputs[0]; auto output = op->outputs[0];
PADDLE_ENFORCE_GE(inputs.size(), 1); PADDLE_ENFORCE_GE(inputs.size(), 1,
PADDLE_ENFORCE_EQ(op->outputs.size(), 1); platform::errors::InvalidArgument(
"OP(%s)'s inputs(%d) must be equal or greater than 1.",
op->Name(), inputs.size()));
PADDLE_ENFORCE_EQ(op->outputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s outputs(%d) must be equal to 1.", op->Name(),
op->outputs.size()));
// create a quantize op desc prototype // create a quantize op desc prototype
OpDesc q_desc; OpDesc q_desc;
...@@ -153,8 +165,8 @@ void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output, ...@@ -153,8 +165,8 @@ void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output,
std::find(outputs.begin(), outputs.end(), output_name) != outputs.end(); std::find(outputs.begin(), outputs.end(), output_name) != outputs.end();
PADDLE_ENFORCE_EQ(name_found, true, PADDLE_ENFORCE_EQ(name_found, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"%s isn't the output of the %s operator", output_name, "Var(%s) isn't the output of the %s operator.",
op->Op()->Type())); output_name, op->Op()->Type()));
unsigned max = is_unsigned ? U8_MAX : S8_MAX; unsigned max = is_unsigned ? U8_MAX : S8_MAX;
float scale = scale_to_one * max; float scale = scale_to_one * max;
...@@ -239,12 +251,23 @@ void CPUQuantizePass::QuantizeConv(Graph* graph, ...@@ -239,12 +251,23 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
auto* conv_op_desc = conv_op->Op(); auto* conv_op_desc = conv_op->Op();
// skip if should not be quantized // skip if should not be quantized
if (!conv_op_desc->GetAttrIfExists<bool>("use_quantizer")) return; if (!conv_op_desc->GetAttrIfExists<bool>("use_quantizer")) {
LogQuantizationDisabled(conv_op);
return;
}
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
auto has_output_scale = AreScalesPresentForNodes(conv_op, {conv_output});
if (with_residual_data && !has_output_scale) {
LogCannotQuantizeOp(conv_op,
"Conv op with ResidualData input cannot be quantized "
"without output scale.");
return;
}
if (with_residual_data) { if (with_residual_data) {
GET_IR_NODE_FROM_SUBGRAPH(conv_residual_data, conv_residual_data, GET_IR_NODE_FROM_SUBGRAPH(conv_residual_data, conv_residual_data,
conv_pattern); conv_pattern);
...@@ -283,7 +306,7 @@ void CPUQuantizePass::QuantizeConv(Graph* graph, ...@@ -283,7 +306,7 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
conv_op->Op()->SetAttr("Scale_weights", filter_scale); conv_op->Op()->SetAttr("Scale_weights", filter_scale);
// if quantization scale is missing for output tensor, return fp32 data // if quantization scale is missing for output tensor, return fp32 data
if (AreScalesPresentForNodes(conv_op, {conv_output})) { if (has_output_scale) {
bool is_output_unsigned{false}; bool is_output_unsigned{false};
auto output_scale = auto output_scale =
GetScaleValueForNode(conv_output, &is_output_unsigned); GetScaleValueForNode(conv_output, &is_output_unsigned);
...@@ -333,9 +356,13 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const { ...@@ -333,9 +356,13 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
auto* fc_op_desc = fc->Op(); auto* fc_op_desc = fc->Op();
// skip if should not be quantized // skip if should not be quantized
if (fc_op_desc->GetAttrIfExists<bool>("use_quantizer") != true || if (!fc_op_desc->GetAttrIfExists<bool>("use_quantizer")) {
fc_op_desc->GetAttrIfExists<bool>("use_mkldnn") != true) LogQuantizationDisabled(fc);
return;
}
if (!fc_op_desc->GetAttrIfExists<bool>("use_mkldnn")) {
return; return;
}
GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(input, input, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(input, input, fc_pattern);
...@@ -396,7 +423,10 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const { ...@@ -396,7 +423,10 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const {
auto* pool_op_desc = pool_op->Op(); auto* pool_op_desc = pool_op->Op();
// skip if should not be quantized // skip if should not be quantized
if (!pool_op_desc->GetAttrIfExists<bool>("use_quantizer")) return; if (!pool_op_desc->GetAttrIfExists<bool>("use_quantizer")) {
LogQuantizationDisabled(pool_op);
return;
}
GET_IR_NODE_FROM_SUBGRAPH(pool_input, pool_input, pool_pattern); GET_IR_NODE_FROM_SUBGRAPH(pool_input, pool_input, pool_pattern);
GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern); GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern);
...@@ -438,7 +468,10 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const { ...@@ -438,7 +468,10 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
auto* concat_op_desc = concat_op->Op(); auto* concat_op_desc = concat_op->Op();
// skip if should not be quantized // skip if should not be quantized
if (!concat_op_desc->GetAttrIfExists<bool>("use_quantizer")) return; if (!concat_op_desc->GetAttrIfExists<bool>("use_quantizer")) {
LogQuantizationDisabled(concat_op);
return;
}
GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern); GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);
...@@ -481,7 +514,10 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const { ...@@ -481,7 +514,10 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
auto* prior_box_op_desc = prior_box_op->Op(); auto* prior_box_op_desc = prior_box_op->Op();
// skip if should not be quantized // skip if should not be quantized
if (!prior_box_op_desc->GetAttrIfExists<bool>("use_quantizer")) return; if (!prior_box_op_desc->GetAttrIfExists<bool>("use_quantizer")) {
LogQuantizationDisabled(prior_box_op);
return;
}
GET_IR_NODE_FROM_SUBGRAPH(prior_box_input, prior_box_input, GET_IR_NODE_FROM_SUBGRAPH(prior_box_input, prior_box_input,
prior_box_pattern); prior_box_pattern);
...@@ -522,6 +558,7 @@ void CPUQuantizePass::QuantizeTranspose(Graph* graph) const { ...@@ -522,6 +558,7 @@ void CPUQuantizePass::QuantizeTranspose(Graph* graph) const {
// skip if should not be quantized // skip if should not be quantized
if (!transpose_op_desc->GetAttrIfExists<bool>("use_quantizer")) { if (!transpose_op_desc->GetAttrIfExists<bool>("use_quantizer")) {
LogQuantizationDisabled(transpose_op);
return; return;
} }
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, transpose_pattern); GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, transpose_pattern);
...@@ -576,6 +613,7 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const { ...@@ -576,6 +613,7 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const {
// skip if should not be quantized // skip if should not be quantized
if (!reshape_op_desc->GetAttrIfExists<bool>("use_quantizer")) { if (!reshape_op_desc->GetAttrIfExists<bool>("use_quantizer")) {
LogQuantizationDisabled(reshape_op);
return; return;
} }
GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, reshape_pattern); GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, reshape_pattern);
...@@ -628,6 +666,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -628,6 +666,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
// skip if should not be quantized // skip if should not be quantized
if (!matmul_op_desc->GetAttrIfExists<bool>("use_quantizer")) { if (!matmul_op_desc->GetAttrIfExists<bool>("use_quantizer")) {
LogQuantizationDisabled(matmul_op);
return; return;
} }
GET_IR_NODE_FROM_SUBGRAPH(prev_op_x, prev_op_x, matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(prev_op_x, prev_op_x, matmul_pattern);
...@@ -649,10 +688,12 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -649,10 +688,12 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
bool is_x_unsigned{false}, is_y_unsigned{false}; bool is_x_unsigned{false}, is_y_unsigned{false};
auto input_x_scale = GetScaleValueForNode(matmul_in_x, &is_x_unsigned); auto input_x_scale = GetScaleValueForNode(matmul_in_x, &is_x_unsigned);
auto input_y_scale = GetScaleValueForNode(matmul_in_y, &is_y_unsigned); auto input_y_scale = GetScaleValueForNode(matmul_in_y, &is_y_unsigned);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(is_x_unsigned, is_y_unsigned,
is_x_unsigned, is_y_unsigned, platform::errors::InvalidArgument(
platform::errors::InvalidArgument( "Matmul inputs should have the same "
"Matmul inputs should have the same value of is_unsigned")); "attribute of signed/unsigned, but they "
"are different: x(%d), y(%d).",
is_x_unsigned, is_y_unsigned));
QuantizeInput(g, matmul_op, matmul_in_x, "X", input_x_scale, is_x_unsigned, QuantizeInput(g, matmul_op, matmul_in_x, "X", input_x_scale, is_x_unsigned,
"Scale_x"); "Scale_x");
QuantizeInput(g, matmul_op, matmul_in_y, "Y", input_y_scale, is_y_unsigned, QuantizeInput(g, matmul_op, matmul_in_y, "Y", input_y_scale, is_y_unsigned,
...@@ -676,12 +717,88 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -676,12 +717,88 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
PrettyLogDetail("--- quantized %d matmul ops", quantize_matmul_count); PrettyLogDetail("--- quantized %d matmul ops", quantize_matmul_count);
} }
void CPUQuantizePass::QuantizeElementwiseAdd(Graph* graph) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
elementwise_add_pattern(
pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()),
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
int quantize_elementwise_add_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Quantize elementwise_add op";
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
auto* elementwise_add_op_desc = elementwise_add_op->Op();
// skip if should not be quantized
if (!elementwise_add_op_desc->GetAttrIfExists<bool>("use_quantizer")) {
LogQuantizationDisabled(elementwise_add_op);
return;
}
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
if (!AreScalesPresentForNodes(elementwise_add_op,
{elementwise_add_x, elementwise_add_y})) {
LogCannotQuantizeOp(elementwise_add_op);
return;
}
bool is_x_unsigned{false}, is_y_unsigned{false};
auto input_x_scale =
GetScaleValueForNode(elementwise_add_x, &is_x_unsigned);
auto input_y_scale =
GetScaleValueForNode(elementwise_add_y, &is_y_unsigned);
// TODO(sfraczek): add support for different signedness
if (is_x_unsigned != is_y_unsigned) {
LogCannotQuantizeOp(elementwise_add_op,
"ElementwiseAdd inputs must be of the same type.");
return;
}
QuantizeInput(g, elementwise_add_op, elementwise_add_x, "X", input_x_scale,
is_x_unsigned, "Scale_x");
QuantizeInput(g, elementwise_add_op, elementwise_add_y, "Y", input_y_scale,
is_y_unsigned, "Scale_y");
// if quantization scale is missing for output tensor, return fp32 data
if (AreScalesPresentForNodes(elementwise_add_op, {elementwise_add_out})) {
bool is_output_unsigned{false};
auto output_scale =
GetScaleValueForNode(elementwise_add_out, &is_output_unsigned);
DequantizeOutput(g, elementwise_add_op, elementwise_add_out, "Out",
output_scale, is_output_unsigned, "Scale_out");
} else {
elementwise_add_op->Op()->SetAttr("force_fp32_output", true);
}
++quantize_elementwise_add_count;
};
gpd(graph, handler);
AddStatis(quantize_elementwise_add_count);
PrettyLogDetail("--- quantized %d elementwise_add ops",
quantize_elementwise_add_count);
}
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Quantizing the graph."; VLOG(3) << "Quantizing the graph.";
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
PADDLE_ENFORCE(param_scope()); PADDLE_ENFORCE_NOT_NULL(param_scope(), platform::errors::InvalidArgument(
"Scope cannot be nullptr."));
QuantizeConv(graph, false /* with_residual_data */); QuantizeConv(graph, false /* with_residual_data */);
QuantizeConv(graph, true /* with_residual_data */); QuantizeConv(graph, true /* with_residual_data */);
...@@ -692,6 +809,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -692,6 +809,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeFc(graph); QuantizeFc(graph);
QuantizeReshape(graph); QuantizeReshape(graph);
QuantizeMatmul(graph); QuantizeMatmul(graph);
QuantizeElementwiseAdd(graph);
} }
} // namespace ir } // namespace ir
......
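The Scale_x/Scale_y/Scale_out attributes set by QuantizeInput and DequantizeOutput come from one arithmetic rule: the stored "scale to one" is multiplied by the integer range of the target type (127 for signed int8, 255 for unsigned). A sketch of just that rule, matching the 2.0f * 127 value the tests later in this diff expect (constant names mirror the pass; the computation is restated here, not copied):

    #include <cassert>

    int main() {
      const unsigned S8_MAX = 127, U8_MAX = 255;
      float scale_to_one = 2.0f;  // value the test stores in the scales map
      bool is_unsigned = false;   // signed activations in this example
      float scale = scale_to_one * (is_unsigned ? U8_MAX : S8_MAX);
      assert(scale == 2.0f * 127);
    }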
...@@ -60,6 +60,8 @@ class CPUQuantizePass : public FusePassBase { ...@@ -60,6 +60,8 @@ class CPUQuantizePass : public FusePassBase {
void QuantizeMatmul(Graph* graph) const; void QuantizeMatmul(Graph* graph) const;
void QuantizeElementwiseAdd(Graph* graph) const;
void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name, void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name,
double scale_to_one, bool is_unsigned, double scale_to_one, bool is_unsigned,
std::string scale_attr_name = "") const; std::string scale_attr_name = "") const;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h" #include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
...@@ -82,6 +83,14 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -82,6 +83,14 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetAttr("Scale_x", 1.0f); op->SetAttr("Scale_x", 1.0f);
op->SetAttr("Scale_y", 1.0f); op->SetAttr("Scale_y", 1.0f);
op->SetAttr("Scale_out", 1.0f); op->SetAttr("Scale_out", 1.0f);
} else if (type == "elementwise_add") {
op->SetInput("X", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("use_quantizer", use_quantizer);
op->SetAttr("Scale_x", 1.0f);
op->SetAttr("Scale_y", 1.0f);
op->SetAttr("Scale_out", 1.0f);
} }
} }
...@@ -95,7 +104,8 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, ...@@ -95,7 +104,8 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog, void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog,
const std::initializer_list<std::string> variable_names, const std::initializer_list<std::string> variable_names,
int* original_nodes_num, int* current_nodes_num, int* original_nodes_num, int* current_nodes_num,
std::string var_without_scale = "") { std::string var_without_scale = "",
std::string var_signed = "") {
auto place = paddle::platform::CPUPlace(); auto place = paddle::platform::CPUPlace();
NaiveExecutor exe{place}; NaiveExecutor exe{place};
Scope scope; Scope scope;
...@@ -108,8 +118,7 @@ void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog, ...@@ -108,8 +118,7 @@ void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog,
tensor.Resize({1}); tensor.Resize({1});
auto* ptr = tensor.mutable_data<double>(place); auto* ptr = tensor.mutable_data<double>(place);
ptr[0] = 2.0; ptr[0] = 2.0;
(*scales)[v] = std::make_pair(v == var_signed, std::move(tensor));
(*scales)[v] = std::make_pair(false, std::move(tensor));
} }
(*graph)->SetNotOwned(kParamScopeAttr, &scope); (*graph)->SetNotOwned(kParamScopeAttr, &scope);
...@@ -387,7 +396,7 @@ static const std::initializer_list<std::string> variable_names_reshape = { ...@@ -387,7 +396,7 @@ static const std::initializer_list<std::string> variable_names_reshape = {
// c->Dropout->d // c->Dropout->d
ProgramDesc BuildProgramDescReshape() { ProgramDesc BuildProgramDescReshape() {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : variable_names_transpose) { for (auto& v : variable_names_reshape) {
prog.MutableBlock(0)->Var(v); prog.MutableBlock(0)->Var(v);
} }
SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true); SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
...@@ -402,7 +411,7 @@ ProgramDesc BuildProgramDescReshape() { ...@@ -402,7 +411,7 @@ ProgramDesc BuildProgramDescReshape() {
// c->Dropout->d // c->Dropout->d
ProgramDesc BuildProgramDescReshapeBetweenNonQuantizedOp() { ProgramDesc BuildProgramDescReshapeBetweenNonQuantizedOp() {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : variable_names_transpose) { for (auto& v : variable_names_reshape) {
prog.MutableBlock(0)->Var(v); prog.MutableBlock(0)->Var(v);
} }
...@@ -491,7 +500,7 @@ static const std::initializer_list<std::string> variable_names_matmul = { ...@@ -491,7 +500,7 @@ static const std::initializer_list<std::string> variable_names_matmul = {
ProgramDesc BuildProgramDescMatmul() { ProgramDesc BuildProgramDescMatmul() {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : variable_names_transpose) { for (auto& v : variable_names_matmul) {
prog.MutableBlock(0)->Var(v); prog.MutableBlock(0)->Var(v);
} }
SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true); SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
...@@ -504,7 +513,7 @@ ProgramDesc BuildProgramDescMatmul() { ...@@ -504,7 +513,7 @@ ProgramDesc BuildProgramDescMatmul() {
ProgramDesc BuildProgramDescMatmulNotQuantized() { ProgramDesc BuildProgramDescMatmulNotQuantized() {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : variable_names_transpose) { for (auto& v : variable_names_matmul) {
prog.MutableBlock(0)->Var(v); prog.MutableBlock(0)->Var(v);
} }
SetOp(&prog, "dropout", "Dropout", {"a"}, {"b"}, false); SetOp(&prog, "dropout", "Dropout", {"a"}, {"b"}, false);
...@@ -569,6 +578,97 @@ TEST(CpuQuantizePass, matmul_not_quantized) { ...@@ -569,6 +578,97 @@ TEST(CpuQuantizePass, matmul_not_quantized) {
MainTestMatmul(BuildProgramDescMatmulNotQuantized(), matmul_count, MainTestMatmul(BuildProgramDescMatmulNotQuantized(), matmul_count,
quant_count, dequant_count, added_nodes_count, 1.0f); quant_count, dequant_count, added_nodes_count, 1.0f);
} }
static const std::initializer_list<std::string> variable_names_elementwise_add =
{"a", "b", "c", "d", "e", "f"};
ProgramDesc BuildProgramDescElementwiseAdd() {
ProgramDesc prog;
for (auto& v : variable_names_elementwise_add) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
SetOp(&prog, "dequantize", "Dequantize2", {"c"}, {"d"}, true);
SetOp(&prog, "elementwise_add", "ElementwiseAdd", {"b", "d"}, {"e"}, true,
true);
SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, false);
return prog;
}
void MainTestElementwiseAdd(const ProgramDesc& prog, int elementwise_add_count,
int quant_count, int dequant_count,
int added_nodes_count, float scale,
bool output_scale_missing = false,
bool unsigned_and_signed_input = false) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num, current_nodes_num;
PreparePass(&graph, prog, variable_names_elementwise_add, &original_nodes_num,
&current_nodes_num, output_scale_missing ? "e" : "",
unsigned_and_signed_input ? "b" : "");
int quantize_nodes_count = 0;
int dequantize_nodes_count = 0;
int elementwise_add_nodes_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == "elementwise_add") {
elementwise_add_nodes_count++;
if (unsigned_and_signed_input) scale = 1.0f;
auto op_name = BOOST_GET_CONST(std::string, op->GetAttr("name"));
EXPECT_EQ(BOOST_GET_CONST(float, op->GetAttr("Scale_x")), scale)
<< "Scale_x for node '" + op_name + "'.";
EXPECT_EQ(BOOST_GET_CONST(float, op->GetAttr("Scale_y")), scale)
<< "Scale_y for node '" + op_name + "'.";
if (output_scale_missing) scale = 1.0;
EXPECT_EQ(BOOST_GET_CONST(float, op->GetAttr("Scale_out")), scale)
<< "Scale_out for node '" + op_name + "'.";
} else if (op->Type() == "quantize") {
quantize_nodes_count++;
} else if (op->Type() == "dequantize") {
dequantize_nodes_count++;
}
}
}
EXPECT_EQ(elementwise_add_nodes_count, elementwise_add_count);
EXPECT_EQ(quantize_nodes_count, quant_count);
EXPECT_EQ(dequantize_nodes_count, dequant_count);
EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
}
TEST(CpuQuantizePass, elementwise_add) {
int elementwise_add_count = 1;
int quant_count = 2;
int dequant_count = 3;
// 2 Quant + 2 IN + 1 DeQuant + 1 OUT
int added_nodes_count = 6;
MainTestElementwiseAdd(BuildProgramDescElementwiseAdd(),
elementwise_add_count, quant_count, dequant_count,
added_nodes_count, 2.0f * 127);
}
TEST(CpuQuantizePass, elementwise_add_output_scale_missing) {
int elementwise_add_count = 1;
int quant_count = 2;
int dequant_count = 2;
// 2 Quant + 2 IN
int added_nodes_count = 4;
MainTestElementwiseAdd(BuildProgramDescElementwiseAdd(),
elementwise_add_count, quant_count, dequant_count,
added_nodes_count, 2.0f * 127, true);
}
TEST(CpuQuantizePass, elementwise_add_unsigned_and_signed_input) {
int elementwise_add_count = 1;
int quant_count = 0;
int dequant_count = 2;
int added_nodes_count = 0;
MainTestElementwiseAdd(BuildProgramDescElementwiseAdd(),
elementwise_add_count, quant_count, dequant_count,
added_nodes_count, 2.0f * 127, false, true);
}
} // namespace } // namespace
} // namespace ir } // namespace ir
......
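One way to read the added_nodes_count values in these tests (an interpretation of the comments above, not code from the pass): each quantized input adds a quantize op plus its new var, and each dequantized output adds a dequantize op plus its new var.

    #include <cassert>

    int main() {
      // "2 Quant + 2 IN + 1 DeQuant + 1 OUT" -> 6 added nodes
      int quant_ops = 2, quant_in_vars = 2;
      int dequant_ops = 1, dequant_out_vars = 1;
      assert(quant_ops + quant_in_vars + dequant_ops + dequant_out_vars == 6);
      // With the output scale missing, no dequantize pair is added: 4 nodes.
    }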
...@@ -75,7 +75,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash( ...@@ -75,7 +75,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
BOOST_GET_CONST(float, quant_op->Op()->GetAttr("Scale")); BOOST_GET_CONST(float, quant_op->Op()->GetAttr("Scale"));
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
nodes_keep_counter->find(dequant_out), nodes_keep_counter->end(), nodes_keep_counter->find(dequant_out), nodes_keep_counter->end(),
platform::errors::NotFound("The dequant output node is not found")); platform::errors::NotFound("The dequant output node is not found."));
// check if dequantize op should be kept or removed, decrease the counter // check if dequantize op should be kept or removed, decrease the counter
bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1; bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1;
...@@ -153,8 +153,9 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const { ...@@ -153,8 +153,9 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
any_op_output_name.empty(), true, any_op_output_name.empty(), true,
platform::errors::NotFound("Operator before requantize operator " platform::errors::NotFound("Operator before requantize operator(%s) "
"should have requantize input as output")); "should have requantize input as output.",
requant_in->Name()));
float requant_scale_out = float requant_scale_out =
BOOST_GET_CONST(float, requant_op->Op()->GetAttr("Scale_out")); BOOST_GET_CONST(float, requant_op->Op()->GetAttr("Scale_out"));
...@@ -195,10 +196,11 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const { ...@@ -195,10 +196,11 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
for (auto input_name : any_op->Op()->Input(name)) for (auto input_name : any_op->Op()->Input(name))
if (input_name == requant_out->Name()) any_op_input_name = name; if (input_name == requant_out->Name()) any_op_input_name = name;
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(any_op_input_name.empty(), true,
any_op_input_name.empty(), true, platform::errors::NotFound(
platform::errors::NotFound("The operator after requantize operator " "The operator after requantize operator(%s) "
"should have requantize output as input")); "should have requantize output as input.",
requant_out->Name()));
float requant_scale_in = float requant_scale_in =
boost::get<float>(requant_op->Op()->GetAttr("Scale_in")); boost::get<float>(requant_op->Op()->GetAttr("Scale_in"));
...@@ -206,11 +208,14 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const { ...@@ -206,11 +208,14 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
if (any_op->Op()->Type() == "matmul") if (any_op->Op()->Type() == "matmul")
scale_name = any_op_input_name == "X" ? "Scale_x" : "Scale_y"; scale_name = any_op_input_name == "X" ? "Scale_x" : "Scale_y";
PADDLE_ENFORCE_EQ(requant_op->Op()->GetAttrIfExists<float>("Scale_out"), PADDLE_ENFORCE_EQ(
any_op->Op()->GetAttrIfExists<float>(scale_name), requant_op->Op()->GetAttrIfExists<float>("Scale_out"),
platform::errors::InvalidArgument( any_op->Op()->GetAttrIfExists<float>(scale_name),
"The operator after requantize should have input " platform::errors::InvalidArgument(
"scale equal to requantize output scale")); "The operator after requantize should have input "
"scale(%f) equal to requantize output scale(%f).",
any_op->Op()->GetAttrIfExists<float>(scale_name),
requant_op->Op()->GetAttrIfExists<float>("Scale_out")));
any_op->Op()->SetAttr(scale_name, requant_scale_in); any_op->Op()->SetAttr(scale_name, requant_scale_in);
any_op->Op()->SetInput(any_op_input_name, any_op->Op()->SetInput(any_op_input_name,
std::vector<std::string>({requant_in->Name()})); std::vector<std::string>({requant_in->Name()}));
...@@ -286,8 +291,9 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { ...@@ -286,8 +291,9 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
auto* first_quant_out = first_quant_op->outputs[0]; auto* first_quant_out = first_quant_op->outputs[0];
float scale = first_quant_op->Op()->GetAttrIfExists<float>("Scale"); float scale = first_quant_op->Op()->GetAttrIfExists<float>("Scale");
PADDLE_ENFORCE_NE(scale, 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_NE(scale, 0,
"Quantize scale should not be equal 0")); platform::errors::InvalidArgument(
"Quantize scale(%f) should not be equal 0.", scale));
for (int iter = prev_out->outputs.size() - 1; iter >= 0; iter--) { for (int iter = prev_out->outputs.size() - 1; iter >= 0; iter--) {
auto quant_op = prev_out->outputs[iter]; auto quant_op = prev_out->outputs[iter];
...@@ -304,8 +310,9 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { ...@@ -304,8 +310,9 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
last_op_input_name.empty(), true, last_op_input_name.empty(), true,
platform::errors::NotFound("Operator after quantize operator " platform::errors::NotFound("Operator after quantize operator(%s) "
"should has quantize output as input")); "should has quantize output as input.",
quant_out->Name()));
last_op->Op()->SetInput( last_op->Op()->SetInput(
last_op_input_name, last_op_input_name,
std::vector<std::string>({first_quant_out->Name()})); std::vector<std::string>({first_quant_out->Name()}));
...@@ -345,10 +352,12 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const { ...@@ -345,10 +352,12 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
PADDLE_ENFORCE_GT(dequant_scale, 0.0f, PADDLE_ENFORCE_GT(dequant_scale, 0.0f,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Dequantize scale should have positive value")); "Dequantize scale(%f) should have positive value.",
dequant_scale));
PADDLE_ENFORCE_GT(scale_scale, 0.0f, PADDLE_ENFORCE_GT(scale_scale, 0.0f,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Scale of scale op should have positive value")); "Scale(%f) of scale op should have positive value.",
scale_scale));
dequant_op->Op()->SetAttr("Scale", dequant_scale / scale_scale); dequant_op->Op()->SetAttr("Scale", dequant_scale / scale_scale);
dequant_op->Op()->SetOutput( dequant_op->Op()->SetOutput(
...@@ -367,8 +376,8 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const { ...@@ -367,8 +376,8 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, graph,
platform::errors::NotFound( platform::errors::InvalidArgument(
"The graph in function CPUQuantizeSquashPass::ApplyImpl is null")); "The graph in function CPUQuantizeSquashPass::ApplyImpl is null."));
FusePassBase::Init("cpu_quantize_squash_pass", graph); FusePassBase::Init("cpu_quantize_squash_pass", graph);
std::unordered_map<const Node*, int> nodes_keep_counter; std::unordered_map<const Node*, int> nodes_keep_counter;
......
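The DequantScaleSquash hunk folds a dequantize (Scale = d) followed by a scale op (scale = s) into a single dequantize with Scale = d / s, which is why both values must be positive. A worked restatement of just that arithmetic:

    #include <cassert>

    int main() {
      float dequant_scale = 127.0f;  // Scale attr of the dequantize op
      float scale_scale = 0.5f;      // scale attr of the following scale op
      assert(dequant_scale > 0.0f && scale_scale > 0.0f);
      float folded = dequant_scale / scale_scale;
      assert(folded == 254.0f);  // new Scale of the squashed dequantize
    }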
...@@ -57,7 +57,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -57,7 +57,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
PADDLE_ENFORCE_EQ(inputs.size(), 2UL, PADDLE_ENFORCE_EQ(inputs.size(), 2UL,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The fc inputs should contain input and weights, but " "The fc inputs should contain input and weights, but "
"now the size of inputs is %d", "now the size of inputs is %d.",
inputs.size())); inputs.size()));
op->SetInput("W", {inputs[1]}); op->SetInput("W", {inputs[1]});
op->SetOutput("Out", outputs); op->SetOutput("Out", outputs);
......
...@@ -19,14 +19,17 @@ namespace paddle { ...@@ -19,14 +19,17 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
#define GET_NODE(id, pattern) \ #define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \ PADDLE_ENFORCE_NE(subgraph.count(pattern.RetrieveNode(#id)), 0, \
"pattern has no Node called %s", #id); \ platform::errors::InvalidArgument( \
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \ "Pattern has no Node called %s.", #id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL( \
id, platform::errors::InvalidArgument("Subgraph has no node %s.", #id));
void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const { void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("depthwise_conv_mkldnn_pass", graph); FusePassBase::Init("depthwise_conv_mkldnn_pass", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
......
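The GET_NODE macro above bundles two checks -- the pattern knows the node, and the subgraph actually matched it -- before introducing a local variable named after the node. A simplified stand-in over a plain map (this mock is an assumption for illustration; the real macro works on PDNode/Node pointers):

    #include <map>
    #include <stdexcept>
    #include <string>

    #define GET_NODE(id, table)                                          \
      auto* id = [&]() -> const std::string* {                           \
        auto it = (table).find(#id);                                     \
        if (it == (table).end())                                         \
          throw std::invalid_argument(                                   \
              "Pattern has no Node called " #id ".");                    \
        return &it->second;                                              \
      }()

    int main() {
      std::map<std::string, std::string> subgraph{{"conv", "conv2d_0"}};
      GET_NODE(conv, subgraph);    // declares `conv` after a checked lookup
      // GET_NODE(relu, subgraph); // would throw: no Node called relu
      return conv->empty() ? 1 : 0;
    }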
...@@ -66,17 +66,17 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -66,17 +66,17 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
return; return;
} }
VLOG(3) << "DNNL Inplace op(" << current_op->id() << ") " VLOG(3) << "oneDNN Inplace op(" << current_op->id() << ") "
<< "Curr Node In: " << current_op_in->Name() << "Curr Node In: " << current_op_in->Name()
<< " Curr Node out: " << current_op_out->Name(); << " Curr Node out: " << current_op_out->Name();
VLOG(3) << "DNNL Inplace next op(" << next_op->id() << ") " VLOG(3) << "oneDNN Inplace next op(" << next_op->id() << ") "
<< " next Node out: " << next_op_out->Name(); << " next Node out: " << next_op_out->Name();
auto inputs = current_op->Op()->Inputs(); auto inputs = current_op->Op()->Inputs();
auto outputs = current_op->Op()->Outputs(); auto outputs = current_op->Op()->Outputs();
auto in_to_outs = infer_inplace(false); // strictly no CUDA for MKL-DNN auto in_to_outs = infer_inplace(false); // strictly no CUDA for MKL-DNN
VLOG(3) << "DNNL InplaceInferer op(" << current_op->id() << ") " VLOG(3) << "oneDNN InplaceInferer op(" << current_op->id() << ") "
<< in_to_outs.begin()->first << ": " << in_to_outs.begin()->first << ": "
<< inputs[in_to_outs.begin()->first][0] << " " << inputs[in_to_outs.begin()->first][0] << " "
<< in_to_outs.begin()->second << ": " << in_to_outs.begin()->second << ": "
...@@ -85,7 +85,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -85,7 +85,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
auto inplace_input_vec = inputs[in_to_outs.begin()->first]; auto inplace_input_vec = inputs[in_to_outs.begin()->first];
if (std::find(inplace_input_vec.begin(), inplace_input_vec.end(), if (std::find(inplace_input_vec.begin(), inplace_input_vec.end(),
current_op_in->Name()) == inplace_input_vec.end()) { current_op_in->Name()) == inplace_input_vec.end()) {
VLOG(3) << "DNNL in-place pass SKIP pattern "; VLOG(3) << "oneDNN in-place pass SKIP pattern ";
return; return;
} }
...@@ -93,7 +93,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -93,7 +93,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
// is used anywhere else apart from inplaced op // is used anywhere else apart from inplaced op
auto input_consumers = current_op_in->outputs; auto input_consumers = current_op_in->outputs;
if (input_consumers.size() > 1) { if (input_consumers.size() > 1) {
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot " VLOG(3) << "oneDNN in-place pass FAIL: in-place var cannot "
"be an input to multiple operators"; "be an input to multiple operators";
return; return;
} else { } else {
...@@ -106,7 +106,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -106,7 +106,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
if ((n->id() != current_op_in->id()) && if ((n->id() != current_op_in->id()) &&
(n->id() != next_op_out->id()) && (n->id() != next_op_out->id()) &&
(n->Name() == current_op_in->Name())) { (n->Name() == current_op_in->Name())) {
VLOG(3) << "DNNL in-place pass FAIL var used in diffrent part of " VLOG(3) << "oneDNN in-place pass FAIL var used in diffrent part of "
"graph "; "graph ";
return; return;
} }
...@@ -122,7 +122,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -122,7 +122,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
original_output_names[current_op->Name() + current_op_in->Name()] = original_output_names[current_op->Name() + current_op_in->Name()] =
current_op_out->Name(); current_op_out->Name();
} else { } else {
VLOG(3) << "DNNL Inplace: Current op already inplaced! "; VLOG(3) << "oneDNN Inplace: Current op already inplaced! ";
} }
// It may be that next op is reusing some of vars, we need to // It may be that next op is reusing some of vars, we need to
...@@ -133,7 +133,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -133,7 +133,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
if ((n_op_infer_inplace == nullptr)) { if ((n_op_infer_inplace == nullptr)) {
for (auto& m : n->outputs) { for (auto& m : n->outputs) {
if (m->Name() == current_op_in->Name()) { if (m->Name() == current_op_in->Name()) {
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot " VLOG(3) << "oneDNN in-place pass FAIL: in-place var cannot "
"be an output to non-inplaced next op"; "be an output to non-inplaced next op";
return; return;
} }
...@@ -173,7 +173,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -173,7 +173,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
(std::find(next_op_inplace_inputs.begin(), (std::find(next_op_inplace_inputs.begin(),
next_op_inplace_inputs.end(), next_op_inplace_inputs.end(),
original_name) != next_op_inplace_inputs.end())) { original_name) != next_op_inplace_inputs.end())) {
VLOG(3) << "DNNL InPlace: Next Op is in-placed , updating its " VLOG(3) << "oneDNN InPlace: Next Op is in-placed , updating its "
"input " "input "
"and output var!"; "and output var!";
next_op->Op()->SetOutput( next_op->Op()->SetOutput(
...@@ -190,10 +190,24 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -190,10 +190,24 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
next_op->Op()->RenameInput(original_name, current_op_out->Name()); next_op->Op()->RenameInput(original_name, current_op_out->Name());
found_inplace_count++; found_inplace_count++;
VLOG(3) << "DNNL InPlace applied!"; VLOG(3) << "oneDNN InPlace applied!";
}; };
gpd(graph, handler); // TODO(jczaja): inplace pass does not influence ops inside block ops
auto should_inplace = [&](Graph* g) {
std::unordered_set<std::string> unwanted_ops(
{"conditional_block", "While", "while_loop"});
for (auto& node : g->Nodes()) {
if (node->IsOp() &&
unwanted_ops.find(node->Name()) != unwanted_ops.end()) {
VLOG(3) << "oneDNN InPlace FAILED: unsupported op: " << node->Name();
return false;
}
}
return true;
};
if (should_inplace(graph)) gpd(graph, handler);
} }
} // namespace ir } // namespace ir
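The new should_inplace guard above disables the whole pass when the graph contains control-flow ops, since variables renamed in the outer graph are not rewritten inside sub-blocks. The same deny-list scan in isolation, with hypothetical names (GraphIsInplaceSafe is not a Paddle API):

#include <string>
#include <unordered_set>
#include <vector>

bool GraphIsInplaceSafe(const std::vector<std::string>& op_types) {
  static const std::unordered_set<std::string> kUnsupported{
      "conditional_block", "While", "while_loop"};
  for (const auto& type : op_types) {
    if (kUnsupported.count(type)) return false;  // bail out on control flow
  }
  return true;
}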
......
...@@ -46,12 +46,15 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -46,12 +46,15 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
if (scale_op->Op()->GetAttrIfExists<float>("bias") == 0.0) { if (scale_op->Op()->GetAttrIfExists<float>("bias") == 0.0) {
auto matmul_alpha = matmul_op->Op()->GetAttrIfExists<float>("alpha"); auto matmul_alpha = matmul_op->Op()->GetAttrIfExists<float>("alpha");
auto scale_scale = scale_op->Op()->GetAttrIfExists<float>("scale"); auto scale_scale = scale_op->Op()->GetAttrIfExists<float>("scale");
PADDLE_ENFORCE_GT(matmul_alpha, 0.0f, PADDLE_ENFORCE_GT(
platform::errors::InvalidArgument( matmul_alpha, 0.0f,
"Alpha of matmul op should have positive value")); platform::errors::InvalidArgument(
"Alpha(%f) of matmul op should have positive value.",
matmul_alpha));
PADDLE_ENFORCE_GT(scale_scale, 0.0f, PADDLE_ENFORCE_GT(scale_scale, 0.0f,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Scale of scale op should have positive value")); "Scale(%f) of scale op should have positive value.",
scale_scale));
std::string matmul_op_input_name; std::string matmul_op_input_name;
for (auto name : matmul_op->Op()->InputNames()) for (auto name : matmul_op->Op()->InputNames())
...@@ -60,8 +63,9 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -60,8 +63,9 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
matmul_op_input_name.empty(), true, matmul_op_input_name.empty(), true,
platform::errors::NotFound("Operator after scale operator " platform::errors::NotFound("Operator after scale operator(%s) "
"should have scale output as input")); "should have scale output as input.",
scale_out->Name()));
matmul_op->Op()->SetAttr("alpha", matmul_alpha * scale_scale); matmul_op->Op()->SetAttr("alpha", matmul_alpha * scale_scale);
matmul_op->Op()->SetInput(matmul_op_input_name, matmul_op->Op()->SetInput(matmul_op_input_name,
std::vector<std::string>({scale_in->Name()})); std::vector<std::string>({scale_in->Name()}));
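The fusion is justified by a simple identity: Paddle's scale op computes scale * X + bias, so with bias == 0 the scale factor can be folded into matmul's alpha attribute, which is exactly what SetAttr("alpha", matmul_alpha * scale_scale) does above. A tiny numeric check of that identity with plain floats (illustrative values only):

#include <cassert>

int main() {
  float x = 2.f, y = 3.f, alpha = 0.5f, scale = 4.f;
  float unfused = ((scale * x) * y) * alpha;  // scale op feeding matmul
  float fused = (x * y) * (alpha * scale);    // matmul with folded alpha
  assert(fused == unfused);
  return 0;
}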
......
...@@ -45,7 +45,9 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -45,7 +45,9 @@ class AllReduceDepsPass : public ir::Pass {
for (size_t i = 0; i < all_reduce_op_handles.size(); ++i) { for (size_t i = 0; i < all_reduce_op_handles.size(); ++i) {
auto op_handle = auto op_handle =
dynamic_cast<details::NCCLOpHandleBase*>(all_reduce_op_handles[i]); dynamic_cast<details::NCCLOpHandleBase*>(all_reduce_op_handles[i]);
PADDLE_ENFORCE(op_handle, "op_handle must be NCCLOpHandleBase"); PADDLE_ENFORCE_NOT_NULL(op_handle,
platform::errors::InvalidArgument(
"Op handle must be NCCLOpHandleBase."));
op_handle->SetRunEnv(i, use_hierarchical_allreduce); op_handle->SetRunEnv(i, use_hierarchical_allreduce);
} }
#endif #endif
...@@ -95,7 +97,9 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -95,7 +97,9 @@ class AllReduceDepsPass : public ir::Pass {
} }
} }
PADDLE_ENFORCE_NE(next_ready_ops.size(), 0, "There maybe have a cycle."); PADDLE_ENFORCE_NE(
next_ready_ops.size(), 0,
platform::errors::InvalidArgument("There may be a cycle."));
ready_ops.clear(); ready_ops.clear();
std::swap(ready_ops, next_ready_ops); std::swap(ready_ops, next_ready_ops);
GetSortedAllReduceOps(ready_ops, &all_reduce_op_handles); GetSortedAllReduceOps(ready_ops, &all_reduce_op_handles);
...@@ -122,18 +126,25 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -122,18 +126,25 @@ class AllReduceDepsPass : public ir::Pass {
// NOTE(zcd): For distributed training, it is important to keep the order of // NOTE(zcd): For distributed training, it is important to keep the order of
// allReduce on each node consistent. Otherwise, hang may occur. // allReduce on each node consistent. Otherwise, hang may occur.
// Sort the current_all_reduce_op_handles according to the name of input. // Sort the current_all_reduce_op_handles according to the name of input.
sort(current_all_reduce_op_handles.begin(), sort(
current_all_reduce_op_handles.end(), current_all_reduce_op_handles.begin(),
[](const details::OpHandleBase* left, current_all_reduce_op_handles.end(),
const details::OpHandleBase* right) -> bool { [](const details::OpHandleBase* left,
auto left_in_vars = const details::OpHandleBase* right) -> bool {
details::DynamicCast<details::VarHandle>(left->Inputs()); auto left_in_vars =
auto right_in_vars = details::DynamicCast<details::VarHandle>(left->Inputs());
details::DynamicCast<details::VarHandle>(right->Inputs()); auto right_in_vars =
PADDLE_ENFORCE_GT(left_in_vars.size(), 0); details::DynamicCast<details::VarHandle>(right->Inputs());
PADDLE_ENFORCE_GT(right_in_vars.size(), 0); PADDLE_ENFORCE_GT(left_in_vars.size(), 0,
return left_in_vars[0]->Name() > right_in_vars[0]->Name(); platform::errors::InvalidArgument(
}); "OpHandle(%s) inputs size must greater than 0.",
left->Name()));
PADDLE_ENFORCE_GT(right_in_vars.size(), 0,
platform::errors::InvalidArgument(
"OpHandle(%s) inputs size must greater than 0.",
right->Name()));
return left_in_vars[0]->Name() > right_in_vars[0]->Name();
});
all_reduce_op_handles->insert(all_reduce_op_handles->end(), all_reduce_op_handles->insert(all_reduce_op_handles->end(),
current_all_reduce_op_handles.begin(), current_all_reduce_op_handles.begin(),
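The NOTE above is the heart of this hunk: collectives such as ncclAllReduce deadlock if different ranks issue them in different orders, so the handles are sorted by a key that is identical on every rank (the first input's name). The same idea in isolation; FakeOp is an illustrative type, not part of Paddle:

#include <algorithm>
#include <string>
#include <vector>

struct FakeOp {
  std::string first_input_name;
};

// Every rank computes the same order because the keys are the same.
void SortForConsistentCollectiveOrder(std::vector<FakeOp>* ops) {
  std::sort(ops->begin(), ops->end(),
            [](const FakeOp& l, const FakeOp& r) {
              return l.first_input_name > r.first_input_name;
            });
}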
...@@ -170,7 +181,10 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -170,7 +181,10 @@ class AllReduceDepsPass : public ir::Pass {
break; break;
} }
} }
PADDLE_ENFORCE(find_valid_input, "Doesn't find valid input."); PADDLE_ENFORCE_EQ(
find_valid_input, true,
platform::errors::NotFound(
"In OpHandle(%s) Doesn't find valid input.", op->Name()));
} }
VLOG(10) << out2.str(); VLOG(10) << out2.str();
if (grads_of_stale_program != all_reduce_op_handles.size()) { if (grads_of_stale_program != all_reduce_op_handles.size()) {
......
...@@ -179,9 +179,10 @@ class BackWardOpDepsPass : public ir::Pass { ...@@ -179,9 +179,10 @@ class BackWardOpDepsPass : public ir::Pass {
// Currently, we assume that once gradient is generated, it can be // Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. // broadcast, and each gradient is only broadcast once.
auto backward_vars = details::GetOpRoleVarsOrEmpty(op_desc); auto backward_vars = details::GetOpRoleVarsOrEmpty(op_desc);
PADDLE_ENFORCE_EQ(node->IsWrappedBy<details::OpHandleBase>(), true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( node->IsWrappedBy<details::OpHandleBase>(), true,
"Node must be wrapped by OpHandleBase")); platform::errors::InvalidArgument(
"Node(%s) must be wrapped by OpHandleBase.", node->Name()));
backward_op_handles->emplace_back(&node->Wrapper<details::OpHandleBase>()); backward_op_handles->emplace_back(&node->Wrapper<details::OpHandleBase>());
......
...@@ -64,9 +64,10 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -64,9 +64,10 @@ class FuseAllReduceOpPass : public ir::Pass {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
all_reduce_ops.size(), grads.size(), all_reduce_ops.size(), grads.size(),
platform::errors::Unimplemented( platform::errors::Unimplemented(
"The number of all_reduce OpHandle is not equal to the " "The number of all_reduce OpHandle(%d) is not equal to the "
"number of grads. Maybe some gradients are sparse type, " "number of grads(%d). Maybe some gradients are sparse type, "
"it is not supported currently.")); "it is not supported currently.",
all_reduce_ops.size(), grads.size()));
auto &group_params_grads = graph->Get<details::GroupParamsAndGrads>( auto &group_params_grads = graph->Get<details::GroupParamsAndGrads>(
details::kGroupParamsAndDenseGrads); details::kGroupParamsAndDenseGrads);
...@@ -79,7 +80,10 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -79,7 +80,10 @@ class FuseAllReduceOpPass : public ir::Pass {
for (auto &group_p_g : group_params_grads) { for (auto &group_p_g : group_params_grads) {
size_t group_size = group_p_g.size(); size_t group_size = group_p_g.size();
PADDLE_ENFORCE_GT(group_size, static_cast<size_t>(0)); PADDLE_ENFORCE_GT(
group_size, static_cast<size_t>(0),
platform::errors::InvalidArgument(
"Parameter and Parameter@grad in one group, must not be empty."));
std::vector<ir::Node *> group_all_reduce_ops; std::vector<ir::Node *> group_all_reduce_ops;
group_all_reduce_ops.reserve(group_size); group_all_reduce_ops.reserve(group_size);
for (auto &p_g : group_p_g) { for (auto &p_g : group_p_g) {
...@@ -103,26 +107,40 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -103,26 +107,40 @@ class FuseAllReduceOpPass : public ir::Pass {
all_reduce_ops.reserve(grads.size()); all_reduce_ops.reserve(grads.size());
for (auto &node : result.Nodes()) { for (auto &node : result.Nodes()) {
if (node->IsOp()) { if (node->IsOp()) {
PADDLE_ENFORCE(node->IsWrappedBy<details::OpHandleBase>()); PADDLE_ENFORCE_EQ(
node->IsWrappedBy<details::OpHandleBase>(), true,
platform::errors::InvalidArgument(
"Op Node(%s) should Wrapped by OpHandleBase.", node->Name()));
auto *all_reduce_op_handle = dynamic_cast<details::AllReduceOpHandle *>( auto *all_reduce_op_handle = dynamic_cast<details::AllReduceOpHandle *>(
&node->Wrapper<details::OpHandleBase>()); &node->Wrapper<details::OpHandleBase>());
if (all_reduce_op_handle) { if (all_reduce_op_handle) {
#if defined(PADDLE_WITH_DGC) #if defined(PADDLE_WITH_DGC)
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
all_reduce_op_handle->Name(), "sparse_all_reduce", all_reduce_op_handle->Name(), "sparse_all_reduce",
"DGC doesn't support fuse for now, if you want to use DGC " platform::errors::InvalidArgument(
"you need set strategy.fuse_all_reduce_ops = False."); "DGC doesn't support fuse for now, if you want to use DGC "
"you need set strategy.fuse_all_reduce_ops = False."));
#endif #endif
auto inputs = details::DynamicCast<details::VarHandle>( auto inputs = details::DynamicCast<details::VarHandle>(
all_reduce_op_handle->Inputs()); all_reduce_op_handle->Inputs());
PADDLE_ENFORCE_EQ(inputs.size(), num_place); PADDLE_ENFORCE_EQ(inputs.size(), num_place,
platform::errors::InvalidArgument(
"The input size(%d) of all reduce op must "
"equal to place cnt(%d)!",
inputs.size(), num_place));
// The inputs' name should be the same. // The inputs' name should be the same.
auto &grad_name = inputs[0]->name(); auto &grad_name = inputs[0]->name();
for (size_t i = 1; i < inputs.size(); ++i) { for (size_t i = 1; i < inputs.size(); ++i) {
PADDLE_ENFORCE_EQ(inputs[i]->name(), grad_name, PADDLE_ENFORCE_EQ(
"The input name should be the same."); inputs[i]->name(), grad_name,
platform::errors::InvalidArgument(
"The input name should be the same.diff name: %s %s.",
inputs[i]->name(), grad_name));
} }
PADDLE_ENFORCE_NE(grads.count(grad_name), static_cast<size_t>(0)); PADDLE_ENFORCE_NE(
grads.count(grad_name), static_cast<size_t>(0),
platform::errors::InvalidArgument(
"Parameter@grad(%s) must in grad set.", grad_name));
all_reduce_ops.emplace(grad_name, node); all_reduce_ops.emplace(grad_name, node);
} }
} }
......
...@@ -24,7 +24,10 @@ namespace ir { ...@@ -24,7 +24,10 @@ namespace ir {
class SSAGraghBuilderWithChecker : public ir::Pass { class SSAGraghBuilderWithChecker : public ir::Pass {
protected: protected:
void ApplyImpl(ir::Graph *graph) const override { void ApplyImpl(ir::Graph *graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph)); PADDLE_ENFORCE_EQ(
IsValidGraph(graph), true,
platform::errors::InvalidArgument(
"In SSAGraghBuilderWithChecker, invalid Graph input."));
} }
bool IsValidGraph(const ir::Graph *graph) const { bool IsValidGraph(const ir::Graph *graph) const {
......
...@@ -163,7 +163,13 @@ void MultiDevSSAGraphBuilderBase::Init() const { ...@@ -163,7 +163,13 @@ void MultiDevSSAGraphBuilderBase::Init() const {
nccl_ctxs_ = multi_nccl_ctxs_->DefaultFlatCtx(); nccl_ctxs_ = multi_nccl_ctxs_->DefaultFlatCtx();
} }
#endif #endif
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); PADDLE_ENFORCE_EQ(
places_.size(), local_scopes_.size(),
platform::errors::InvalidArgument(
"Places size and LocalScopes not equal "
"Places size(%d), LocalScopes size(%d) "
"If use multi devices, Places size must equas to LocalScopes size.",
places_.size(), local_scopes_.size()));
} }
void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const { void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
...@@ -500,7 +506,11 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result, ...@@ -500,7 +506,11 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
SetCommunicationContext(op_handle, places_[i]); SetCommunicationContext(op_handle, places_[i]);
auto &vars = result->Get<details::GraphVars>(details::kGraphVars)[i][og]; auto &vars = result->Get<details::GraphVars>(details::kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE_EQ(vars.empty(), false,
platform::errors::InvalidArgument(
"Can not find Var(%s) in Place[%d] "
"Paddle Can not add AllReduce OP for Var(%s).",
og, i, og));
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad); op_handle->AddInput(prev_grad);
VLOG(10) << "all_reduce_op_handle add input " << prev_grad->DebugString(); VLOG(10) << "all_reduce_op_handle add input " << prev_grad->DebugString();
...@@ -566,7 +576,11 @@ details::VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp( ...@@ -566,7 +576,11 @@ details::VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
auto &p = places_[i]; auto &p = places_[i];
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, p);
auto &vars = result->Get<details::GraphVars>(details::kGraphVars)[i][og]; auto &vars = result->Get<details::GraphVars>(details::kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE_EQ(vars.empty(), false,
platform::errors::InvalidArgument(
"Can not find Var(%s) in Place[%d] "
"Paddle Can not add Reduce OP for Var(%s).",
og, i, og));
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad); op_handle->AddInput(prev_grad);
} }
...@@ -590,7 +604,11 @@ bool MultiDevSSAGraphBuilderBase::IsScaleLossOp(ir::Node *node) const { ...@@ -590,7 +604,11 @@ bool MultiDevSSAGraphBuilderBase::IsScaleLossOp(ir::Node *node) const {
bool MultiDevSSAGraphBuilderBase::IsSparseGradient( bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
const std::string &og) const { const std::string &og) const {
PADDLE_ENFORCE(all_vars_.count(og) != 0); PADDLE_ENFORCE_NE(all_vars_.count(og), 0,
platform::errors::InvalidArgument(
"Can not find Var(%s) in VarDescs "
"Paddle Can not add Collective OP for Var(%s).",
og, og));
return all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS; return all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS;
} }
...@@ -641,10 +659,20 @@ int BalanceVarSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { ...@@ -641,10 +659,20 @@ int BalanceVarSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
std::vector<std::string>, std::vector<std::string>,
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U); PADDLE_ENFORCE_EQ(
param_grad.size(), 2U,
platform::errors::InvalidArgument(
"In Node %s, the size of attribute %s must be 2, include Parameter "
"and Parameter@Grad.",
node->Name(), OpProtoAndCheckerMaker::OpRoleVarAttrName()));
int dev_id = GetVarDeviceID(param_grad[1]); int dev_id = GetVarDeviceID(param_grad[1]);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]", PADDLE_ENFORCE_NE(dev_id, -1, platform::errors::NotFound(
node->Op()->Type(), param_grad[0], param_grad[1]); "Cannot find the device ID for NodeName:%s, "
"NodeType:%s, Param:%s, Param@Grad:%s. "
"For this fault, you can consult the "
"Paddle technical personnel for an answer.",
node->Name(), node->Op()->Type(),
param_grad[0], param_grad[1]));
return dev_id; return dev_id;
} }
...@@ -654,10 +682,16 @@ size_t BalanceVarSSAGraphBuilder::GetAppropriateDeviceID( ...@@ -654,10 +682,16 @@ size_t BalanceVarSSAGraphBuilder::GetAppropriateDeviceID(
for (auto var_name : var_names) { for (auto var_name : var_names) {
if (all_vars_.find(var_name) == all_vars_.end()) continue; if (all_vars_.find(var_name) == all_vars_.end()) continue;
auto var_desc = all_vars_.at(var_name); auto var_desc = all_vars_.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc); PADDLE_ENFORCE_NOT_NULL(var_desc,
platform::errors::NotFound(
"Can not find Var(%s) in Var Desc.", var_name));
auto dim = framework::make_ddim(var_desc->GetShape()); auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim); int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GT(numel, 0); PADDLE_ENFORCE_GT(numel, 0,
platform::errors::InvalidArgument(
"The numel of Var(%s) must greater than 0"
"Please check your code,about Var(%s) Shape.",
var_name, var_name));
numel_sum += numel; numel_sum += numel;
} }
...@@ -736,7 +770,12 @@ int ReduceSSAGraphBuilder::GetOpDeviceID( ...@@ -736,7 +770,12 @@ int ReduceSSAGraphBuilder::GetOpDeviceID(
std::vector<std::string>, std::vector<std::string>,
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U); PADDLE_ENFORCE_EQ(
param_grad.size(), 2U,
platform::errors::InvalidArgument(
"In Node %s, The size of attribute %s must be 2, include Parameter "
"and Parameter@Grad.",
node->Name(), OpProtoAndCheckerMaker::OpRoleVarAttrName()));
int dev_id = GetVarDeviceID(param_grad[1]); int dev_id = GetVarDeviceID(param_grad[1]);
if (dev_id == -1) { if (dev_id == -1) {
...@@ -798,7 +837,12 @@ std::vector<ir::Node *> ReduceSSAGraphBuilder::SortForReduceMode( ...@@ -798,7 +837,12 @@ std::vector<ir::Node *> ReduceSSAGraphBuilder::SortForReduceMode(
} }
} }
PADDLE_ENFORCE_EQ(sorted_ops.size(), topo_ops.size()); PADDLE_ENFORCE_EQ(sorted_ops.size(), topo_ops.size(),
platform::errors::InvalidArgument(
"Sorted ops calc error!"
"The result for sorted ops size(%d) must be "
"equal to topo ops size(%d).",
sorted_ops.size(), topo_ops.size()));
ResetState(); ResetState();
return sorted_ops; return sorted_ops;
...@@ -820,14 +864,23 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result, ...@@ -820,14 +864,23 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
bool insert_op = false; bool insert_op = false;
if (OpHaveRole(*node, OpRole::kRPC)) { if (OpHaveRole(*node, OpRole::kRPC)) {
int op_dev_id = CreateRPCOp(result, node); int op_dev_id = CreateRPCOp(result, node);
PADDLE_ENFORCE(op_dev_id != -1, PADDLE_ENFORCE_NE(op_dev_id, -1, platform::errors::InvalidArgument(
"Can not schedule the RPC operator to the right place."); "Can not schedule the RPC operator to "
"the right place. NodeName:%s.",
node->Name()));
if (node->Op()->Type() == "recv") { if (node->Op()->Type() == "recv") {
auto recv_vars_attr = auto recv_vars_attr =
BOOST_GET_CONST(std::vector<std::string>, BOOST_GET_CONST(std::vector<std::string>,
node->Op()->GetNullableAttr( node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName())); OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE(recv_vars_attr.size() == 2UL); // [parameter, gradient] PADDLE_ENFORCE_EQ(
recv_vars_attr.size(), 2UL,
platform::errors::InvalidArgument(
"In Node %s, the size of attribute %s must be 2, include "
"Parameter and Parameter@Grad.",
node->Name(),
OpProtoAndCheckerMaker::OpRoleVarAttrName())); // [parameter,
// gradient]
if (recv_vars_attr[0].find(".block") == std::string::npos) { if (recv_vars_attr[0].find(".block") == std::string::npos) {
bcast_var_name_set_[op_dev_id].emplace(recv_vars_attr[0]); bcast_var_name_set_[op_dev_id].emplace(recv_vars_attr[0]);
} }
...@@ -879,8 +932,9 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { ...@@ -879,8 +932,9 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
if (node->Op()->Type() == "send") { if (node->Op()->Type() == "send") {
// TODO(paddle-dev): getting the first var is not safe. // TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(node->inputs[0]->Name()); op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]), PADDLE_ENFORCE_EQ(ir::IsControlDepVar(*node->inputs[0]), false,
"This hack no longer holds, please fix."); platform::errors::InvalidArgument(
"This hack no longer holds, please fix."));
// the variable name which contains .block means it was split by // the variable name which contains .block means it was split by
// split_byref op // split_byref op
if (strategy_.reduce_ == if (strategy_.reduce_ ==
...@@ -893,7 +947,12 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { ...@@ -893,7 +947,12 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
auto send_param_grad = BOOST_GET_CONST( auto send_param_grad = BOOST_GET_CONST(
std::vector<std::string>, std::vector<std::string>,
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(send_param_grad.size(), 2U); PADDLE_ENFORCE_EQ(
send_param_grad.size(), 2U,
platform::errors::InvalidArgument(
"In Node %s, the size of attribute %s must be 2, include "
"Parameter and Parameter@Grad.",
node->Name(), OpProtoAndCheckerMaker::OpRoleVarAttrName()));
op_dev_id = GetAppropriateDeviceID({send_param_grad[1]}); op_dev_id = GetAppropriateDeviceID({send_param_grad[1]});
VLOG(10) << "send grad " << input_var_names[0] << " origin " VLOG(10) << "send grad " << input_var_names[0] << " origin "
<< send_param_grad[1] << " place: " << op_dev_id; << send_param_grad[1] << " place: " << op_dev_id;
...@@ -926,9 +985,10 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { ...@@ -926,9 +985,10 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
op_dev_id = 0; op_dev_id = 0;
} }
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", PADDLE_ENFORCE_NE(
node->Op()->Type()); op_dev_id, -1,
platform::errors::NotFound("Can not find the right place for rpc op: %s.",
node->Op()->Type()));
// Create fetch_barrier op handle to enable output on all devices. // Create fetch_barrier op handle to enable output on all devices.
// **NOTE** fetch_barrier should output variables list same as recv op does. // **NOTE** fetch_barrier should output variables list same as recv op does.
if (node->Op()->Type() == "fetch_barrier") { if (node->Op()->Type() == "fetch_barrier") {
...@@ -956,7 +1016,10 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { ...@@ -956,7 +1016,10 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
int outvar_dev_id = op_dev_id; int outvar_dev_id = op_dev_id;
if (node->Op()->Type() == "fetch_barrier") { if (node->Op()->Type() == "fetch_barrier") {
outvar_dev_id = GetVarDeviceID(output->Name()); outvar_dev_id = GetVarDeviceID(output->Name());
PADDLE_ENFORCE_NE(outvar_dev_id, -1, "output name %s", output->Name()); PADDLE_ENFORCE_NE(outvar_dev_id, -1,
platform::errors::NotFound(
"Can not find the right place for the var: %s.",
output->Name()));
} }
p = places_[outvar_dev_id]; p = places_[outvar_dev_id];
ir::Node *new_node = nullptr; ir::Node *new_node = nullptr;
...@@ -1007,13 +1070,14 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -1007,13 +1070,14 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
} else { } else {
LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type(); LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
PADDLE_THROW( PADDLE_THROW(
"the distribute training related op should be in [split_byref, " platform::errors::Unimplemented("The distribute training related op "
"concat]."); "should be in [split_byref, concat]."));
} }
PADDLE_ENFORCE(op_dev_id != -1, PADDLE_ENFORCE_NE(op_dev_id, -1,
"can not find right place for distributed op: %s", platform::errors::NotFound(
node->Op()->Type()); "Can not find right place for distributed op: %s.",
node->Op()->Type()));
CreateComputationalOp(result, node, op_dev_id); CreateComputationalOp(result, node, op_dev_id);
return op_dev_id; return op_dev_id;
......
...@@ -28,7 +28,10 @@ class SSAGraghBuilderWithPrinterPass : public ir::Pass { ...@@ -28,7 +28,10 @@ class SSAGraghBuilderWithPrinterPass : public ir::Pass {
void ApplyImpl(ir::Graph *graph) const override { void ApplyImpl(ir::Graph *graph) const override {
std::unique_ptr<std::ostream> fout( std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath))); new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE_EQ(
fout->good(), true,
platform::errors::Unavailable("Open file fail! kGraphvizPath = %s.",
Get<std::string>(kGraphvizPath)));
if (Has("graph_printer")) { if (Has("graph_printer")) {
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout); Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
} else { } else {
......
...@@ -54,11 +54,16 @@ class SequentialExecutionPass : public ir::Pass { ...@@ -54,11 +54,16 @@ class SequentialExecutionPass : public ir::Pass {
if (!node->IsOp()) continue; if (!node->IsOp()) continue;
std::unordered_set<ir::Node *> preceding_ops; std::unordered_set<ir::Node *> preceding_ops;
for (auto *in : node->inputs) { for (auto *in : node->inputs) {
PADDLE_ENFORCE(in->IsVar(), PADDLE_ENFORCE_EQ(
"Preceding Node of Op Nodes must be Var Node"); in->IsVar(), true,
platform::errors::InvalidArgument(
"Preceding Node(%s) of Op Nodes must be Var Node.",
in->Name()));
if (in->inputs.empty()) continue; if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(), PADDLE_ENFORCE_EQ((in->inputs.size() == 1 && in->inputs[0]->IsOp()),
"Preceding Op Node of Var Node must be unique"); true,
platform::errors::InvalidArgument(
"Preceding Op Node of Var Node must be unique."));
preceding_ops.insert(in->inputs[0]); preceding_ops.insert(in->inputs[0]);
pending_ops[in->inputs[0]].insert(node); pending_ops[in->inputs[0]].insert(node);
} }
...@@ -72,15 +77,18 @@ class SequentialExecutionPass : public ir::Pass { ...@@ -72,15 +77,18 @@ class SequentialExecutionPass : public ir::Pass {
ir::Node *found_node = nullptr; ir::Node *found_node = nullptr;
for (auto *node : ready_ops) { for (auto *node : ready_ops) {
if (IsSameOpDesc(op_desc, node->Op())) { if (IsSameOpDesc(op_desc, node->Op())) {
PADDLE_ENFORCE(found_node == nullptr, PADDLE_ENFORCE_EQ(
"Found multiple op_desc in graph: %s", found_node, nullptr,
op_desc->Type()); platform::errors::InvalidArgument(
"Found multiple op_desc in graph: %s.", op_desc->Type()));
found_node = node; found_node = node;
} }
} }
PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s", PADDLE_ENFORCE_NOT_NULL(
op_desc->Type()); found_node,
platform::errors::NotFound("Cannot find op_desc in graph: %s.",
op_desc->Type()));
for (auto *pending_op : pending_ops[found_node]) { for (auto *pending_op : pending_ops[found_node]) {
if (--op_deps.at(pending_op) == 0) { if (--op_deps.at(pending_op) == 0) {
ready_ops.insert(pending_op); ready_ops.insert(pending_op);
......
...@@ -45,13 +45,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { ...@@ -45,13 +45,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
// Create pattern. // Create pattern.
MultiHeadMatmulPattern multihead_pattern(pattern, name_scope); MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
PDNode* x = multihead_pattern();
pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
multihead_pattern(x);
// Create New OpDesc // Create New OpDesc
auto fuse_creater = [&]( auto fuse_creater = [&](
Node* x, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out, Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
Node* mul1_out, Node* mul2_out, Node* eltadd0_b, Node* eltadd1_b, Node* mul1_out, Node* mul2_out, Node* eltadd0_b, Node* eltadd1_b,
Node* eltadd2_b, Node* eltadd_qk_b, Node* reshape2, Node* eltadd2_b, Node* eltadd_qk_b, Node* reshape2,
Node* reshape2_qkv_out, Node* scale, Node* scale_out) { Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
...@@ -115,7 +112,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { ...@@ -115,7 +112,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern); // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
...@@ -185,7 +182,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { ...@@ -185,7 +182,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern); multihead_pattern);
fuse_creater(layer_norm, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0,
reshape2_qkv_out, scale, scale_out); reshape2_qkv_out, scale, scale_out);
...@@ -232,12 +229,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { ...@@ -232,12 +229,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
return fusion_count; return fusion_count;
} }
PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) { PDNode* MultiHeadMatmulPattern::operator()() {
// Create shared nodes. auto* input0 = pattern->NewNode(input0_repr());
auto* layer_norm = pattern->NewNode(layer_norm_repr()); input0->assert_is_op_input("mul");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr());
layer_norm_out_var->assert_is_op_input("mul");
// First path with scale // First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul"); auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul");
...@@ -390,17 +384,15 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) { ...@@ -390,17 +384,15 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
transpose2_2_out_var->AsIntermediate()->assert_is_op_input( transpose2_2_out_var->AsIntermediate()->assert_is_op_input(
"matmul"); // link to matmul qkv "matmul"); // link to matmul qkv
// Link all nodes together
layer_norm->LinksFrom({x}).LinksTo({layer_norm_out_var});
// Q path // Q path
mul0->LinksFrom({layer_norm_out_var, mul0_w_var}).LinksTo({mul0_out_var}); mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var}); eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
scale->LinksFrom({transpose2_0_out_var}).LinksTo({scale_out_var}); scale->LinksFrom({transpose2_0_out_var}).LinksTo({scale_out_var});
// K path // K path
mul1->LinksFrom({layer_norm_out_var, mul1_w_var}).LinksTo({mul1_out_var}); mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var}); eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var}); reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var}); transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
...@@ -411,7 +403,7 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) { ...@@ -411,7 +403,7 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
.LinksTo({eltadd_qk_out_var}); .LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
// V path // V path
mul2->LinksFrom({layer_norm_out_var, mul2_w_var}).LinksTo({mul2_out_var}); mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var}); eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var}); reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
...@@ -434,13 +426,10 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -434,13 +426,10 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
// Create pattern. // Create pattern.
MultiHeadMatmulPattern multihead_pattern(pattern, name_scope); MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
PDNode* x = multihead_pattern();
pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
multihead_pattern(x);
// Create New OpDesc // Create New OpDesc
auto fuse_creater = [&]( auto fuse_creater = [&](
Node* layer_norm_out, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out, Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w, Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w,
Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b, Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b,
Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) { Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
...@@ -471,29 +460,20 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -471,29 +460,20 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
framework::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]}); framework::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
auto combined_bias_dims = framework::make_ddim({3, bq_tensor->dims()[0]}); auto combined_bias_dims = framework::make_ddim({3, bq_tensor->dims()[0]});
// create a new var in scope // reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
VarDesc combined_w_desc( auto* combined_w_desc = mul0_w->Var();
patterns::PDNodeName(name_scope, "multi_head_combined_weight")); combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
combined_w_desc.SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]}); combined_w_desc->SetPersistable(true);
combined_w_desc.SetDataType(wq_tensor->type());
combined_w_desc.SetLoDLevel(mul0_w->Var()->GetLoDLevel()); auto* combined_bias_desc = eltadd0_b->Var();
combined_w_desc.SetPersistable(true); combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
combined_bias_desc->SetPersistable(true);
// create a new var in scope
VarDesc combined_bias_desc( framework::LoDTensor tmp_combined_w_tensor;
patterns::PDNodeName(name_scope, "multi_head_combined_bias")); tmp_combined_w_tensor.Resize(combined_w_dims);
combined_bias_desc.SetShape({3, bq_tensor->dims()[0]}); auto* tmp_combined_w_data =
combined_bias_desc.SetDataType(bq_tensor->type()); tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());
combined_bias_desc.SetLoDLevel(eltadd0_b->Var()->GetLoDLevel());
combined_bias_desc.SetPersistable(true);
auto* combined_w_node = graph->CreateVarNode(&combined_w_desc);
auto* combined_w_tensor =
scope->Var(combined_w_node->Name())->GetMutable<LoDTensor>();
combined_w_tensor->Resize(combined_w_dims);
auto* combined_w_data =
combined_w_tensor->mutable_data<float>(platform::CPUPlace());
std::vector<float*> w_vec = {wq_data, wk_data, wv_data}; std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2]; int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
// Combine the three fc weights together. // Combine the three fc weights together.
...@@ -502,25 +482,38 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -502,25 +482,38 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
for (int k = 0; k < dims_w; k++) { for (int k = 0; k < dims_w; k++) {
int out_index = i * (3 * dims_w) + j * dims_w + k; int out_index = i * (3 * dims_w) + j * dims_w + k;
int in_index = i * dims_w + k; int in_index = i * dims_w + k;
combined_w_data[out_index] = w_vec[j][in_index]; tmp_combined_w_data[out_index] = w_vec[j][in_index];
} }
} }
} }
scope->EraseVars({mul0_w->Name(), mul1_w->Name(), mul2_w->Name()});
auto* combined_bias_node = graph->CreateVarNode(&combined_bias_desc); wq_tensor->Resize(combined_w_dims);
auto* combined_bias_tensor = auto* new_combined_w_data =
scope->Var(combined_bias_node->Name())->GetMutable<LoDTensor>(); wq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_w_data, tmp_combined_w_data,
combined_bias_tensor->Resize(combined_bias_dims); sizeof(float) * wq_tensor->numel());
auto* combined_bias_data =
combined_bias_tensor->mutable_data<float>(platform::CPUPlace()); scope->EraseVars({mul1_w->Name(), mul2_w->Name()});
framework::LoDTensor tmp_combined_bias_tensor;
tmp_combined_bias_tensor.Resize(combined_bias_dims);
auto* tmp_combined_bias_data =
tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());
size_t bias_size = bq_tensor->numel(); size_t bias_size = bq_tensor->numel();
memcpy(combined_bias_data, bq_data, sizeof(float) * bias_size); memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
memcpy(combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size); memcpy(tmp_combined_bias_data + bias_size, bk_data,
memcpy(combined_bias_data + 2 * bias_size, bv_data, sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
sizeof(float) * bias_size); sizeof(float) * bias_size);
scope->EraseVars({eltadd0_b->Name(), eltadd1_b->Name(), eltadd2_b->Name()}); bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_bias_data, tmp_combined_bias_data,
sizeof(float) * bq_tensor->numel());
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
auto reshape_desc = reshape2->Op(); auto reshape_desc = reshape2->Op();
int head_number = int head_number =
...@@ -529,9 +522,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -529,9 +522,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
OpDesc multihead_op_desc; OpDesc multihead_op_desc;
multihead_op_desc.SetType("multihead_matmul"); multihead_op_desc.SetType("multihead_matmul");
multihead_op_desc.SetInput("Input", {layer_norm_out->Name()}); multihead_op_desc.SetInput("Input", {input0->Name()});
multihead_op_desc.SetInput("W", {combined_w_node->Name()}); multihead_op_desc.SetInput("W", {mul0_w->Name()});
multihead_op_desc.SetInput("Bias", {combined_bias_node->Name()}); multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()});
multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()}); multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});
multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()}); multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
...@@ -540,9 +533,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -540,9 +533,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
auto* multihead = graph->CreateOpNode(&multihead_op_desc); auto* multihead = graph->CreateOpNode(&multihead_op_desc);
IR_NODE_LINK_TO(layer_norm_out, multihead); IR_NODE_LINK_TO(input0, multihead);
IR_NODE_LINK_TO(combined_w_node, multihead); IR_NODE_LINK_TO(mul0_w, multihead);
IR_NODE_LINK_TO(combined_bias_node, multihead); IR_NODE_LINK_TO(eltadd0_b, multihead);
IR_NODE_LINK_TO(eltadd_qk_b, multihead); IR_NODE_LINK_TO(eltadd_qk_b, multihead);
IR_NODE_LINK_TO(multihead, reshape2_qkv_out); IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
...@@ -552,9 +545,7 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -552,9 +545,7 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern); // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
...@@ -624,14 +615,13 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -624,14 +615,13 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern); multihead_pattern);
fuse_creater(layer_norm_out, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out); reshape2_0, reshape2_qkv_out, scale, scale_out);
std::unordered_set<const Node*> marked_nodes({eltadd0, std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1, eltadd1,
eltadd2, eltadd2,
eltadd0_b,
eltadd1_b, eltadd1_b,
eltadd2_b, eltadd2_b,
eltadd0_out, eltadd0_out,
...@@ -665,7 +655,6 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -665,7 +655,6 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
mul0_out, mul0_out,
mul1_out, mul1_out,
mul2_out, mul2_out,
mul0_w,
mul1_w, mul1_w,
mul2_w, mul2_w,
reshape2_qkv, reshape2_qkv,
......
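The V2 fusion above packs the three [H, W] projection weights for Q, K and V into a single [H, 3, W] tensor stored in mul0's weight var, so one GEMM produces all three projections. A standalone sketch of the packing with plain arrays, matching the index arithmetic in the diff (CombineQKV is a hypothetical helper, not Paddle code):

#include <vector>

// out[i][j][k] = w_vec[j][i * dims_w + k], where j indexes {Q, K, V}.
std::vector<float> CombineQKV(const std::vector<const float*>& w_vec,
                              int dims_h, int dims_w) {
  std::vector<float> combined(static_cast<size_t>(dims_h) * 3 * dims_w);
  for (int i = 0; i < dims_h; ++i) {
    for (int j = 0; j < 3; ++j) {
      for (int k = 0; k < dims_w; ++k) {
        combined[i * (3 * dims_w) + j * dims_w + k] = w_vec[j][i * dims_w + k];
      }
    }
  }
  return combined;
}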
...@@ -29,11 +29,10 @@ struct MultiHeadMatmulPattern : public PatternBase { ...@@ -29,11 +29,10 @@ struct MultiHeadMatmulPattern : public PatternBase {
MultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope) MultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "multihead_matmul") {} : PatternBase(pattern, name_scope, "multihead_matmul") {}
PDNode* operator()(PDNode* x); PDNode* operator()();
// declare operator node's name // declare operator node's name
PATTERN_DECL_NODE(layer_norm); PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(mul0); PATTERN_DECL_NODE(mul0);
PATTERN_DECL_NODE(mul1); PATTERN_DECL_NODE(mul1);
PATTERN_DECL_NODE(mul2); PATTERN_DECL_NODE(mul2);
......
...@@ -22,10 +22,7 @@ ...@@ -22,10 +22,7 @@
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
__global__ void test(size_t* a, int size) { __global__ void test(size_t* a, int size) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; CUDA_KERNEL_LOOP(i, size) { a[i] *= 2; }
i += blockDim.x * gridDim.x) {
a[i] *= 2;
}
} }
TEST(LoD, data) { TEST(LoD, data) {
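Judging from the loop it replaces here, CUDA_KERNEL_LOOP is a grid-stride-loop macro; a plausible expansion is sketched below (an assumption about Paddle's definition, not a quote of its header):

// Assumed expansion of CUDA_KERNEL_LOOP(i, n): a grid-stride loop, so a
// fixed-size kernel launch still covers arbitrarily large n.
#define CUDA_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)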
......
...@@ -23,8 +23,13 @@ namespace framework { ...@@ -23,8 +23,13 @@ namespace framework {
void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) { Dataset* dataset) {
pipeline_num_ = trainer_desc.thread_num(); const auto& section_params = trainer_desc.section_param();
VLOG(3) << "pipeline num: " << pipeline_num_; num_microbatches_ = section_params.num_microbatches();
VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_;
section_num_ = section_params.section_config_size();
VLOG(3) << "Number of program sections: " << section_num_;
trainer_desc_ = trainer_desc;
start_cpu_core_id_ = section_params.start_cpu_core_id();
SetDataset(dataset); SetDataset(dataset);
ParseDumpConfig(trainer_desc); ParseDumpConfig(trainer_desc);
...@@ -32,96 +37,62 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -32,96 +37,62 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
const std::vector<paddle::framework::DataFeed*> readers = const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders(); dataset->GetReaders();
VLOG(3) << "readers num: " << readers.size(); VLOG(3) << "readers num: " << readers.size();
int num_readers = readers.size();
pipeline_config_ = trainer_desc.section_param(); PADDLE_ENFORCE_EQ(num_readers, 1,
scope_queue_size_ = pipeline_config_.queue_size(); platform::errors::InvalidArgument(
sync_steps_ = pipeline_config_.sync_steps(); "Number of dataset readers for pipeline "
section_num_ = pipeline_config_.section_config_size(); "must be 1 now, but the value given is %d.",
num_readers));
VLOG(3) << "scope_queue_size: " << scope_queue_size_; auto* reader = readers[0];
VLOG(3) << "section num: " << section_num_; feed_var_names_ = reader->GetUseSlotAlias();
VLOG(3) << "sync_steps: " << sync_steps_;
workers_.resize(section_num_); workers_.resize(section_num_);
in_var_names_.resize(section_num_);
out_var_names_.resize(section_num_);
worker_count_.resize(section_num_);
worker_count_mutex_.resize(section_num_);
param_need_sync_.reset(new std::vector<std::string>);
int reader_index = 0;
for (int i = 0; i < section_num_; ++i) { for (int i = 0; i < section_num_; ++i) {
const auto& section_config = pipeline_config_.section_config(i); const auto& section_config = section_params.section_config(i);
int concurrency = section_config.concurrency();
VLOG(3) << "the thread num of each pipeline in section " << i
<< " is: " << concurrency;
in_var_names_[i].reset(new std::vector<std::string>(
section_config.section_in_var_names().begin(),
section_config.section_in_var_names().end()));
out_var_names_[i].reset(new std::vector<std::string>(
section_config.section_out_var_names().begin(),
section_config.section_out_var_names().end()));
worker_count_[i].resize(pipeline_num_);
worker_count_mutex_[i].resize(pipeline_num_);
for (int j = 0; j < pipeline_num_; ++j) {
worker_count_[i][j] = new int(concurrency);
worker_count_mutex_[i][j].reset(new std::mutex);
}
platform::Place place; platform::Place place;
workers_[i].resize(pipeline_num_); int place_id = section_config.place_id();
for (int j = 0; j < pipeline_num_; ++j) { switch (section_config.place()) {
workers_[i][j].resize(concurrency); case SectionConfig::CPUPlace:
place = platform::CPUPlace();
switch (section_config.place()) { break;
case SectionConfig::CPUPlace: case SectionConfig::CUDAPlace:
place = platform::CPUPlace(); // Note that one section has at most one GPU place in one pipeline
break; PADDLE_ENFORCE_GE(
case SectionConfig::CUDAPlace: place_id, 0,
// Note that one section has at most one GPU place in one pipeline "The place_id value for CUDAPlace should be greater "
place = platform::CUDAPlace(j); "than or equal to 0, but the value given is %d.",
break; "than or equal to 0, but the value you give is %d.",
case SectionConfig::CUDAPinnedPlace: place_id));
place = platform::CUDAPinnedPlace(); place = platform::CUDAPlace(place_id);
break; break;
default: case SectionConfig::CUDAPinnedPlace:
PADDLE_ENFORCE(false, "Unkown place type in SectionConfig: %d", place = platform::CUDAPinnedPlace();
section_config.place()); break;
} default:
PADDLE_ENFORCE_NOT_NULL(nullptr,
platform::errors::InvalidArgument(
"Unkown place type in SectionConfig: %d",
section_config.place()));
}
places_.emplace_back(place);
VLOG(3) << "Device worker place: " << place << ", device id: " << place_id
<< ", section: " << i;
for (int k = 0; k < concurrency; ++k) { workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
workers_[i][j][k] = DeviceWorkerFactory::CreateDeviceWorker( trainer_desc.device_worker_name());
trainer_desc.device_worker_name()); auto this_worker =
auto this_worker = std::dynamic_pointer_cast<paddle::framework::SectionWorker>(
std::dynamic_pointer_cast<paddle::framework::SectionWorker>( workers_[i]);
workers_[i][j][k]); if (i == 0) {
this_worker->SetSectionIndex(i); // we only set reader for the first section
this_worker->SetDeviceIndex(j); this_worker->SetDataFeed(reader);
this_worker->SetThreadIndex(k); this_worker->SetReaderPlace(place);
this_worker->SetSectionNum(section_num_);
this_worker->SetPipelineNum(pipeline_num_);
if (i == 0) {
this_worker->SetDataFeed(readers[reader_index++]);
this_worker->SetReaderPlace(place);
}
if (i == section_num_ - 1) {
this_worker->SetNeedDumpField(need_dump_field_);
this_worker->SetNeedDumpParam(need_dump_param_);
this_worker->SetDumpFieldVector(dump_fields_);
this_worker->SetDumpParamVector(dump_param_);
}
this_worker->SetPlace(place);
this_worker->Initialize(trainer_desc);
this_worker->InitRandomDumpConfig(trainer_desc);
}
} }
} this_worker->SetThreadIndex(i);
param_need_sync_.reset( this_worker->SetSectionIndex(i);
new std::vector<std::string>(pipeline_config_.param_need_sync().begin(), this_worker->SetPlace(place);
pipeline_config_.param_need_sync().end())); this_worker->Initialize(trainer_desc);
VLOG(3) << "param_need_sync_ have: "; this_worker->SetMicrobatchNum(num_microbatches_);
for (const std::string& name : *param_need_sync_) {
VLOG(3) << name;
} }
// set debug here // set debug here
SetDebug(trainer_desc.debug()); SetDebug(trainer_desc.debug());
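The rewritten Initialize drops the per-pipeline scope queues in favor of a fixed two-level hierarchy built later in InitTrainerEnv: one minibatch scope per section for persistable and feed vars, and one child scope per microbatch for transient vars. A hedged sketch of that layout; Scope here is a placeholder, not paddle::framework::Scope:

#include <memory>
#include <vector>

struct Scope {  // placeholder just large enough for the sketch
  Scope& NewScope() {
    kids.emplace_back(std::make_unique<Scope>());
    return *kids.back();
  }
  std::vector<std::unique_ptr<Scope>> kids;
};

// One minibatch scope per section; one microbatch scope per
// (section, microbatch) pair underneath it.
void BuildScopes(Scope* root, int section_num, int num_microbatches,
                 std::vector<Scope*>* minibatch,
                 std::vector<std::vector<Scope*>>* microbatch) {
  minibatch->resize(section_num);
  microbatch->resize(section_num);
  for (int s = 0; s < section_num; ++s) {
    (*minibatch)[s] = &root->NewScope();
    for (int m = 0; m < num_microbatches; ++m) {
      (*microbatch)[s].push_back(&(*minibatch)[s]->NewScope());
    }
  }
}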
...@@ -140,13 +111,7 @@ std::string PipelineTrainer::GetDumpPath(int tid) { ...@@ -140,13 +111,7 @@ std::string PipelineTrainer::GetDumpPath(int tid) {
void PipelineTrainer::InitDumpEnv() { void PipelineTrainer::InitDumpEnv() {
queue_ = paddle::framework::MakeChannel<std::string>(); queue_ = paddle::framework::MakeChannel<std::string>();
// Only set dump channel on the last section // TODO(sandyhouse): should make this a config
for (int j = 0; j < pipeline_num_; ++j) {
for (size_t k = 0; k < workers_[section_num_ - 1][j].size(); ++k) {
workers_[section_num_ - 1][j][k]->SetChannelWriter(queue_.get());
}
}
// TODO(hutuxian): should make it as a config
dump_thread_num_ = 1; dump_thread_num_ = 1;
for (int i = 0; i < dump_thread_num_; i++) { for (int i = 0; i < dump_thread_num_; i++) {
dump_thread_.push_back( dump_thread_.push_back(
...@@ -154,150 +119,105 @@ void PipelineTrainer::InitDumpEnv() { ...@@ -154,150 +119,105 @@ void PipelineTrainer::InitDumpEnv() {
} }
} }
-void PipelineTrainer::InitFirstScopeQueue(ScopeQueue* scope_queue,
-                                          int pipeline_id,
-                                          const ProgramDesc& main_program,
-                                          const Scope& root_scope) {
-  for (int i = 0; i < scope_queue_size_; ++i) {
-    Scope* scope = &pipeline_scopes_[pipeline_id]->NewScope();
-    for (auto& var : main_program.Block(0).AllVars()) {
-      if (!var->Persistable()) {
-        auto* ptr = scope->Var(var->Name());
-        InitializeVariable(ptr, var->GetType());
-      } else {
-        if (section_num_ == 1) {  // Means only one section and it must be
-                                  // CUDAPlace, so copy all persistable vars to
-                                  // pipeline scope
-          const LoDTensor& root_tensor =
-              root_scope.FindVar(var->Name())->Get<LoDTensor>();
-          LoDTensor* gpu_tensor = pipeline_scopes_[pipeline_id]
-                                      ->Var(var->Name())
-                                      ->GetMutable<LoDTensor>();
-          platform::Place place = platform::CUDAPlace(pipeline_id);
-          TensorCopy(*static_cast<const Tensor*>(&root_tensor), place,
-                     static_cast<Tensor*>(gpu_tensor));
-        }
-      }
-    }
-    scope_queue->Send(scope);
-  }
-}
+void PipelineTrainer::CopyParameters(int section_id, int microbatch_id,
+                                     const ProgramDesc& program,
+                                     const platform::Place& place) {
+  auto& global_block = program.Block(0);
+  for (auto& var : global_block.AllVars()) {
+    int is_feed_var =
+        std::count(feed_var_names_.begin(), feed_var_names_.end(), var->Name());
+    if ((var->Persistable() || is_feed_var) && microbatch_id == 0) {
+      if (is_feed_var) {
+        auto* new_ptr = minibatch_scopes_[section_id]->Var(var->Name());
+        VLOG(3) << "data name: " << var->Name() << ", ptr: " << new_ptr;
+        InitializeVariable(new_ptr, var->GetType());
+      } else {
+        auto* ptr = root_scope_->FindVar(var->Name());
+        auto* new_ptr = minibatch_scopes_[section_id]->Var(var->Name());
+        VLOG(3) << "Create persistable var " << var->Name() << " for minibatch "
+                << section_id << ", which pointer is " << new_ptr;
+        InitializeVariable(new_ptr, var->GetType());
+        const LoDTensor& root_tensor = ptr->Get<LoDTensor>();
+        LoDTensor* minibatch_tensor = new_ptr->GetMutable<LoDTensor>();
+        TensorCopy(*static_cast<const Tensor*>(&root_tensor), place,
+                   static_cast<Tensor*>(minibatch_tensor));
+      }
+    } else if (!var->Persistable() && !is_feed_var) {
+      auto* ptr =
+          microbatch_scopes_[section_id][microbatch_id]->Var(var->Name());
+      VLOG(3) << "Create variable " << var->Name() << " for section "
+              << section_id << " microbatch " << microbatch_id
+              << ", which pointer is " << ptr;
+      InitializeVariable(ptr, var->GetType());
+    }
+  }
+}
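The placement rule in the new CopyParameters is the heart of the change:
persistable parameters and feed variables are created once per section, in the
minibatch scope and only when microbatch_id == 0, while every other variable
gets a private copy in each microbatch scope. A standalone sketch of just that
decision logic (the helper and enum are illustrative, not Paddle API):

#include <iostream>

enum class Placement { kMinibatch, kMicrobatch, kSkip };

// Mirrors the branch structure of CopyParameters above.
Placement PlaceVar(bool persistable, bool is_feed_var, int microbatch_id) {
  if ((persistable || is_feed_var) && microbatch_id == 0)
    return Placement::kMinibatch;   // shared by all microbatches of a section
  if (!persistable && !is_feed_var)
    return Placement::kMicrobatch;  // private to one microbatch
  return Placement::kSkip;          // persistable var revisited for microbatch > 0
}

int main() {
  std::cout << (PlaceVar(true, false, 0) == Placement::kMinibatch) << "\n";   // 1
  std::cout << (PlaceVar(false, false, 3) == Placement::kMicrobatch) << "\n"; // 1
  std::cout << (PlaceVar(true, false, 3) == Placement::kSkip) << "\n";        // 1
  return 0;
}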
-void PipelineTrainer::CopyParameters(const Scope& root_scope, int pipeline_id) {
-  for (const std::string& name : *param_need_sync_) {
-    const LoDTensor& root_tensor = root_scope.FindVar(name)->Get<LoDTensor>();
-
-    // TODO(hutxian): check a new var of the same name is created in
-    // pipeline_scope
-    LoDTensor* gpu_tensor =
-        pipeline_scopes_[pipeline_id]->Var(name)->GetMutable<LoDTensor>();
-    platform::Place place = platform::CUDAPlace(pipeline_id);
-    TensorCopy(*static_cast<const Tensor*>(&root_tensor), place,
-               static_cast<Tensor*>(gpu_tensor));
-  }
-}
+void PipelineTrainer::GetSkipVars(int section_id, const ProgramDesc& program) {
+  auto& global_block = program.Block(0);
+  for (auto& op : global_block.AllOps()) {
+    if (op->Type() != "enqueue") {
+      continue;
+    }
+    auto input_arg_names = op->InputArgumentNames();
+    PADDLE_ENFORCE_EQ(input_arg_names.size(), 1,
+                      platform::errors::InvalidArgument(
+                          "Number of input arguments for enqueue op must be 1, "
+                          "but the value is %d.",
+                          input_arg_names.size()));
+    std::string input_arg_name = input_arg_names[0];
+    if (input_arg_name.rfind("@GRAD") != input_arg_name.size() - 5) {
+      skip_vars_[section_id].emplace_back(input_arg_name);
+      VLOG(3) << "add skip var name: " << input_arg_name;
+    }
+  }
+}
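GetSkipVars records every enqueue input whose name does not end in "@GRAD" as a
skip variable. A standalone restatement of that suffix test (the helper below
is illustrative, not part of the patch; it also guards the short-name edge case
that the size() - 5 arithmetic in the original relies on size_t wraparound to
handle):

#include <iostream>
#include <string>

bool EndsWithGrad(const std::string& name) {
  const std::string suffix = "@GRAD";
  return name.size() >= suffix.size() &&
         name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0;
}

int main() {
  std::cout << EndsWithGrad("fc_0.w_0") << "\n";       // 0 -> recorded as skip var
  std::cout << EndsWithGrad("fc_0.w_0@GRAD") << "\n";  // 1 -> not a skip var
  return 0;
}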
 void PipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
                                      const platform::Place& place) {
-  PADDLE_ENFORCE(root_scope_, "Null root_scope pointer");
-  SectionWorker::cpu_id_.store(pipeline_config_.start_cpu_core_id());
-  scope_queues_.resize(section_num_);
-  pipeline_scopes_.resize(pipeline_num_);
-  for (auto& var : main_program.Block(0).AllVars()) {
-    if (var->Persistable()) {
-      persistable_vars_.push_back(var->Name());
-    }
-  }
+  PADDLE_ENFORCE_NOT_NULL(root_scope_,
+                          platform::errors::InvalidArgument(
+                              "root_scope pointer can not be nullptr"));
+  auto start_cpu_id = trainer_desc_.section_param().start_cpu_core_id();
+  SectionWorker::cpu_id_.store(start_cpu_id);
+  minibatch_scopes_.resize(section_num_);
+  microbatch_scopes_.resize(section_num_);
+  skip_vars_.resize(section_num_);

   VLOG(3) << "Init ScopeQueues and create all scopes";
   for (int i = 0; i < section_num_; ++i) {
-    for (int j = 0; j < pipeline_num_; ++j) {
-      scope_queues_[i].emplace_back(new ScopeQueue(scope_queue_size_));
-      if (i == 0) {
-        pipeline_scopes_[j] = &root_scope_->NewScope();
-        CopyParameters(*root_scope_, j);
-        InitFirstScopeQueue(scope_queues_[0].back().get(), j, main_program,
-                            *root_scope_);
-      }
-    }
+    minibatch_scopes_[i] = &root_scope_->NewScope();
+    std::shared_ptr<framework::ProgramDesc> program;
+    program.reset(new ProgramDesc(
+        trainer_desc_.section_param().section_config(i).program_desc()));
+    microbatch_scopes_[i].resize(num_microbatches_);
+    for (int j = 0; j < num_microbatches_; ++j) {
+      microbatch_scopes_[i][j] = &minibatch_scopes_[i]->NewScope();
+      CopyParameters(i, j, *program, places_[i]);
+    }
+    GetSkipVars(i, *program);
   }

   for (int i = 0; i < section_num_; ++i) {
-    for (int j = 0; j < pipeline_num_; ++j) {
-      for (size_t k = 0; k < workers_[i][j].size(); ++k) {
-        auto this_worker =
-            std::dynamic_pointer_cast<paddle::framework::SectionWorker>(
-                workers_[i][j][k]);
-        this_worker->SetRootScope(root_scope_);
-        this_worker->SetCountMutex(worker_count_mutex_[i][j].get());
-        this_worker->SetWorkerCount(worker_count_[i][j]);
-        this_worker->SetScopeQueue(scope_queues_[i][j].get(),
-                                   (i == section_num_ - 1)
-                                       ? scope_queues_[0][j].get()
-                                       : scope_queues_[i + 1][j].get());
-        this_worker->SetVarNames(*in_var_names_[i], *out_var_names_[i]);
-        if (i != section_num_ - 1) {
-          // For data copy in adjacent different place
-          this_worker->SetNextSectionPlace(
-              std::dynamic_pointer_cast<paddle::framework::SectionWorker>(
-                  workers_[i + 1][j][0])
-                  ->place());
-        }
-      }
-    }
-  }
-  if (pipeline_num_ > 1 && sync_steps_ != -1) {
-    construct_sync_functor();
-  }
-}
-
-void PipelineTrainer::construct_sync_functor() {
-  std::vector<platform::Place> cuda_places;
-  for (int i = 0; i < pipeline_num_; ++i) {
-    cuda_places.emplace_back(platform::CUDAPlace(i));
-  }
-  nccl_ctx_map_.reset(new platform::NCCLContextMap(cuda_places));
-  sync_functors_.resize(pipeline_num_);
-  SyncFunctor::sync_flag_ = 0;
-  SyncFunctor::pipeline_scopes_.resize(0);
-
-  for (int j = 0; j < pipeline_num_; ++j) {
-    SyncFunctor* sync_function = new SyncFunctor(j, pipeline_num_, sync_steps_);
-    sync_function->SetSyncParam(*param_need_sync_);
-    sync_function->SetNcclCtxMap(nccl_ctx_map_.get());
-    SyncFunctor::pipeline_scopes_.push_back(this->pipeline_scopes_[j]);
-    sync_functors_[j].reset(sync_function);
-  }
-  for (int i = section_num_ - 1; i >= 0; --i) {
-    if (SectionConfig::CUDAPlace ==
-        pipeline_config_.section_config(i).place()) {
-      for (int j = 0; j < pipeline_num_; ++j) {
-        for (size_t k = 0; k < workers_[i][j].size(); ++k) {
-          auto this_worker =
-              std::dynamic_pointer_cast<paddle::framework::SectionWorker>(
-                  workers_[i][j][k]);
-          this_worker->SetSyncFunctor(sync_functors_[j].get());
-        }
-      }
-      break;
-    }
-  }
-}
+    auto this_worker =
+        std::dynamic_pointer_cast<paddle::framework::SectionWorker>(
+            workers_[i]);
+    this_worker->SetRootScope(root_scope_);
+    this_worker->SetMinibatchScope(minibatch_scopes_[i]);
+    this_worker->SetMicrobatchScopes(microbatch_scopes_[i]);
+    this_worker->SetSkipVars(skip_vars_[i]);
+  }
+}
 void PipelineTrainer::Run() {
   VLOG(3) << "Going to run";
   for (int i = 0; i < section_num_; ++i) {
-    for (int j = 0; j < pipeline_num_; ++j) {
-      for (size_t k = 0; k < workers_[i][j].size(); ++k) {
-        if (!debug_) {
-          section_threads_.push_back(
-              std::thread(&DeviceWorker::TrainFiles, workers_[i][j][k].get()));
-        } else {
-          section_threads_.push_back(std::thread(
-              &DeviceWorker::TrainFilesWithProfiler, workers_[i][j][k].get()));
-        }
-      }
-    }
+    if (!debug_) {
+      section_threads_.push_back(
+          std::thread(&DeviceWorker::TrainFiles, workers_[i].get()));
+    } else {
+      section_threads_.push_back(std::thread(
+          &DeviceWorker::TrainFilesWithProfiler, workers_[i].get()));
+    }
   }
 }
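Run now launches exactly one std::thread per section. A minimal standalone
sketch of the same launch/join pattern (TrainOnce stands in for
DeviceWorker::TrainFiles; the real threads are presumably joined later, during
Finalize):

#include <iostream>
#include <thread>
#include <vector>

void TrainOnce(int section_id) {
  std::cout << "section " << section_id << " running\n";
}

int main() {
  const int section_num = 2;
  std::vector<std::thread> section_threads;
  for (int i = 0; i < section_num; ++i) {
    section_threads.emplace_back(TrainOnce, i);  // one thread per section
  }
  for (auto& t : section_threads) {
    t.join();
  }
  return 0;
}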
@@ -309,18 +229,31 @@ void PipelineTrainer::Finalize() {
   if (need_dump_field_) {
     FinalizeDumpEnv();
   }
-  for (const auto& var : persistable_vars_) {
-    auto* root_tensor = root_scope_->Var(var)->GetMutable<LoDTensor>();
-    // TODO(hutuxian): Add a final all-reduce?
-    const auto& thread_tensor =
-        pipeline_scopes_[0]->FindVar(var)->Get<LoDTensor>();
-    TensorCopySync(thread_tensor, platform::CPUPlace(), root_tensor);
+  VLOG(3) << "copying back parameters. ";
+  for (int i = 0; i < section_num_; ++i) {
+    std::shared_ptr<framework::ProgramDesc> program;
+    program.reset(new ProgramDesc(
+        trainer_desc_.section_param().section_config(i).program_desc()));
+    for (int j = 0; j < num_microbatches_; ++j) {
+      auto& global_block = program->Block(0);
+      for (auto& var : global_block.AllVars()) {
+        if (var->Persistable()) {
+          auto* ptr = root_scope_->FindVar(var->Name());
+          LoDTensor* root_tensor = ptr->GetMutable<LoDTensor>();
+          auto* minibatch_ptr = minibatch_scopes_[i]->Var(var->Name());
+          const LoDTensor& minibatch_tensor = minibatch_ptr->Get<LoDTensor>();
+          TensorCopy(*static_cast<const Tensor*>(&minibatch_tensor), places_[0],
+                     static_cast<Tensor*>(root_tensor));
+          VLOG(4) << "Copy persitable var " << var->Name() << " to root scope";
+        }
+      }
+    }
   }
   root_scope_->DropKids();
 }

 Scope* PipelineTrainer::GetWorkerScope(int thread_id) {
-  return pipeline_scopes_[thread_id];
+  return microbatch_scopes_[thread_id][0];
 }

 }  // end namespace framework
...
@@ -137,49 +137,31 @@ class PipelineTrainer : public TrainerBase {
   virtual Scope* GetWorkerScope(int thread_id);
   void InitDumpEnv() override;
   virtual std::string GetDumpPath(int tid);
+  void GetSkipVars(int section_id, const ProgramDesc& main_program);

  protected:
   int section_num_;
-  int pipeline_num_;
-  int scope_queue_size_;
-  int sync_steps_;
+  int num_microbatches_;
+  int start_cpu_core_id_;
+  std::vector<std::string> feed_var_names_;
+  std::vector<platform::Place> places_;
+  std::vector<std::vector<std::string>> skip_vars_;
+  TrainerDesc trainer_desc_;

-  SectionWorkerParameter pipeline_config_;
-
-  // The in/output var names for each section
-  std::vector<std::unique_ptr<std::vector<std::string>>> in_var_names_;
-  std::vector<std::unique_ptr<std::vector<std::string>>> out_var_names_;
-
-  // Counter for the running thread
-  std::vector<std::vector<int*>> worker_count_;
-  std::vector<std::vector<std::unique_ptr<std::mutex>>> worker_count_mutex_;
-
-  // worker: [section_id][pipeline_id][thread_id]
-  std::vector<std::vector<
-      std::vector<std::shared_ptr<paddle::framework::DeviceWorker>>>>
-      workers_;
   std::vector<std::thread> section_threads_;
-
-  // We use scope to maintain context info, and scopes
-  // will be deliverd between different sections.
-  std::vector<std::vector<std::unique_ptr<ScopeQueue>>> scope_queues_;
-  std::vector<Scope*> pipeline_scopes_;
-
-  // The parameters that should be syncronized between different cards using
-  // nccl all-reduce
-  std::shared_ptr<std::vector<std::string>> param_need_sync_;
-  std::vector<std::string> persistable_vars_;
-  std::vector<std::unique_ptr<SyncFunctor>> sync_functors_;
-  std::shared_ptr<platform::NCCLContextMap> nccl_ctx_map_;
-
-  std::vector<DataFeed*> readers_;
-
-  void InitFirstScopeQueue(ScopeQueue* scope_queue, int pipeline_id,
-                           const ProgramDesc& main_program,
-                           const Scope& root_scope);
-  void CopyParameters(const Scope& root_scope, int pipeline_id);
-  void construct_sync_functor();
+  // worker: [section_id]
+  std::vector<std::shared_ptr<paddle::framework::DeviceWorker>> workers_;
+  // minibatch_scopes_: [section_id]
+  std::vector<Scope*> minibatch_scopes_;
+  // microbatch_scopes_: [section_id][microbatch_id]
+  std::vector<std::vector<Scope*>> microbatch_scopes_;
+
+  void CopyParameters(int section_id, int microbatch_id,
+                      const ProgramDesc& program, const platform::Place& place);
+  bool isPersistableVarGrad(std::string name);
+  bool isPersistable(VarDesc* var);
 };
 #endif

 }  // namespace framework
 }  // namespace paddle
@@ -83,6 +83,7 @@ message SectionWorkerParameter {
   optional int64 sync_steps = 3 [ default = 1 ];
   optional int32 start_cpu_core_id = 4 [ default = 1 ];
   repeated string param_need_sync = 5;
+  optional int32 num_microbatches = 6;
 }

 message SectionConfig {
@@ -99,6 +100,7 @@ message SectionConfig {
   optional int32 concurrency = 3 [ default = 1 ];
   repeated string section_in_var_names = 4;
   repeated string section_out_var_names = 5;
+  optional int32 place_id = 6 [ default = -1 ];
 }

 message FetchConfig {
...
@@ -205,7 +205,9 @@ void BasicEngine::Execute() {
         continue;
       }

-      var = std::make_shared<VariableWrapper>(var->Name());
+      auto tmp_var = std::make_shared<VariableWrapper>(var->Name());
+      tmp_var->SetType(var->Type());
+      var = tmp_var;
       need_accu_var_list_.emplace_back(iter->second.get(), var);
     }
   }
...
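The fix above copies the original variable's type onto the replacement wrapper
before swapping it in, so downstream gradient accumulation sees the right
variable kind. A standalone sketch of the pattern (VariableWrapper and VarType
here are simplified stand-ins; in the real class the type presumably selects
the underlying storage, e.g. LoDTensor vs. SelectedRows):

#include <iostream>
#include <memory>
#include <string>

enum class VarType { kLoDTensor, kSelectedRows };

// Stand-in for imperative::VariableWrapper.
struct VariableWrapper {
  explicit VariableWrapper(std::string n) : name(std::move(n)) {}
  void SetType(VarType t) { type = t; }
  VarType Type() const { return type; }
  std::string name;
  VarType type = VarType::kLoDTensor;  // default-constructed type, as before the fix
};

int main() {
  auto var = std::make_shared<VariableWrapper>("x@GRAD");
  var->SetType(VarType::kSelectedRows);
  // Before the fix, the fresh wrapper silently reverted to the default type;
  // after the fix, the type travels with it.
  auto tmp_var = std::make_shared<VariableWrapper>(var->name);
  tmp_var->SetType(var->Type());
  var = tmp_var;
  std::cout << (var->Type() == VarType::kSelectedRows) << "\n";  // prints 1
  return 0;
}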
@@ -285,6 +285,11 @@ TEST(test_tracer, test_unique_name_generator) {
   auto fc_2 = tracer.GenerateUniqueName("fc");
   ASSERT_STREQ("fc_0", fc_1.c_str());
   ASSERT_STREQ("fc_1", fc_2.c_str());
+  // use `eager_tmp` as key if not specify it.
+  auto tmp_var_2 = tracer.GenerateUniqueName();
+  ASSERT_STREQ("eager_tmp_2", tmp_var_2.c_str());
+  auto tmp_var_3 = tracer.GenerateUniqueName("eager_tmp");
+  ASSERT_STREQ("eager_tmp_3", tmp_var_3.c_str());
 }

 TEST(test_tracer, test_current_tracer) {
...
@@ -32,7 +32,7 @@ namespace imperative {
 class UniqueNameGenerator {
  public:
   explicit UniqueNameGenerator(std::string prefix = "") : prefix_(prefix) {}
-  std::string Generate(std::string key = "tmp") {
+  std::string Generate(std::string key = "eager_tmp") {
     return prefix_ + key + "_" + std::to_string(id_++);
   }

@@ -76,7 +76,14 @@ class Tracer {
     return program_desc_tracer_.get();
   }

-  std::string GenerateUniqueName(std::string key = "tmp") {
+  // Note(Aurelius84): The `tmp` is used as prefix key while naming a temporary
+  // intermediate var both in imperative and static mode. But the
+  // `UniqueNameGenerator` in C++ and `unique_name.py` in Python doesn't share
+  // the same auto-increment id. It will create a variable repeatedly with same
+  // name like `tmp_0` in some cases when transform dygraph into static layers.
+  // So we modify the default prefix key into `eager_tmp` to distinguish with
+  // static graph.
+  std::string GenerateUniqueName(std::string key = "eager_tmp") {
     return generator_->Generate(key);
   }
...
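The behavior of the renamed default key is easy to verify in isolation. Below
is a standalone copy of UniqueNameGenerator (same Generate body as the class
above; the private members are reconstructed from the visible behavior, minus
the surrounding Paddle headers), exercised the way the tracer test does:

#include <iostream>
#include <string>

class UniqueNameGenerator {
 public:
  explicit UniqueNameGenerator(std::string prefix = "") : prefix_(prefix) {}
  std::string Generate(std::string key = "eager_tmp") {
    return prefix_ + key + "_" + std::to_string(id_++);
  }

 private:
  int id_{0};
  std::string prefix_;
};

int main() {
  UniqueNameGenerator gen;
  std::cout << gen.Generate("fc") << "\n";  // fc_0
  std::cout << gen.Generate("fc") << "\n";  // fc_1
  std::cout << gen.Generate() << "\n";      // eager_tmp_2: one shared counter
  return 0;
}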
@@ -36,7 +36,6 @@ endif()
 # fluid_modules exclude API-interface of inference/api and inference/capi
 get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
-get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES)
 add_subdirectory(api)
...
@@ -828,6 +828,25 @@ bool AnalysisPredictor::LoadParameters() {
   return true;
 }

+void AnalysisPredictor::ClearIntermediateTensor() {
+  PADDLE_ENFORCE_NOT_NULL(inference_program_.get(),
+                          platform::errors::PreconditionNotMet(
+                              "The inference program should be loaded first."));
+  const auto &global_block = inference_program_->MutableBlock(0);
+  for (auto *var : global_block->AllVars()) {
+    if (!IsPersistable(var)) {
+      const std::string name = var->Name();
+      auto *variable = executor_->scope()->FindVar(name);
+      if (variable != nullptr && variable->IsType<framework::LoDTensor>() &&
+          name != "feed" && name != "fetch") {
+        VLOG(3) << "Clear Intermediate Tensor: " << name;
+        auto *t = variable->GetMutable<framework::LoDTensor>();
+        t->clear();
+      }
+    }
+  }
+}
+
 #if PADDLE_WITH_TENSORRT
 bool AnalysisPredictor::SaveTrtCalibToDisk() {
   PADDLE_ENFORCE(config_.tensorrt_engine_enabled(),
...
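From the caller's side, the new API frees non-persistable tensors between runs
to cap memory on long-running services. A hedged usage sketch (the model path,
header path, and input preparation are placeholders; the calls themselves
mirror the unit test further down in this change):

#include <vector>
#include "paddle_inference_api.h"  // header location may differ per install

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("./model");  // placeholder path
  auto predictor = paddle::CreatePaddlePredictor(config);

  std::vector<paddle::PaddleTensor> inputs;  // fill with real data in practice
  std::vector<paddle::PaddleTensor> outputs;
  predictor->Run(inputs, &outputs);
  // Release intermediate (non-persistable) tensors before the next request.
  predictor->ClearIntermediateTensor();
  return 0;
}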
@@ -187,6 +187,12 @@ class AnalysisPredictor : public PaddlePredictor {
   ///
   void OptimizeInferenceProgram();
+  ///
+  /// \brief Clear the intermediate tensors of the predictor
+  ///
+  void ClearIntermediateTensor();
+
   ///
   /// \brief Get the argument used by predictor
   ///
...
@@ -49,6 +49,10 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
   rules_["matmul"]["Y"] = ScaleAlgo::KL;
   rules_["matmul"]["Out"] = ScaleAlgo::KL;

+  rules_["elementwise_add"]["X"] = ScaleAlgo::KL;
+  rules_["elementwise_add"]["Y"] = ScaleAlgo::KL;
+  rules_["elementwise_add"]["Out"] = ScaleAlgo::KL;
+
   // Reshape2 does not perform calculation on the data and shapes are not
   // changed. Scale is calculated on input data and assign to Quantize and
   // Dequantize scale.
...
@@ -27,10 +27,10 @@
 #include <memory>
 #include <string>
 #include <vector>
-#include "crypto/cipher.h"
 #include "paddle_infer_declare.h"  // NOLINT

 /*! \namespace paddle
  */
 namespace paddle {

@@ -313,6 +313,12 @@ class PD_INFER_DECL PaddlePredictor {
   /// \return Whether the run is successful
   virtual bool ZeroCopyRun() { return false; }

+  ///
+  /// \brief Clear the intermediate tensors of the predictor
+  ///
+  virtual void ClearIntermediateTensor() {}
+
   /// \brief Clone an existing predictor
   /// When using clone, the same network will be created,
   /// and the parameters between them are shared.
@@ -431,4 +437,6 @@ PD_INFER_DECL std::string get_version();

 PD_INFER_DECL std::string UpdateDllFlag(const char* name, const char* value);

+PD_INFER_DECL std::shared_ptr<framework::Cipher> MakeCipher(
+    const std::string& config_file);
+
 }  // namespace paddle
 cc_library(lite_op_teller SRCS op_teller.cc DEPS lite_full_static framework_proto device_context boost xxhash)
 cc_library(lite_engine SRCS engine.cc DEPS lite_full_static framework_proto)
-cc_library(lite_tensor_utils SRCS tensor_utils.cc DEPS memcpy lite_full_static framework_proto boost)
+cc_library(lite_tensor_utils SRCS tensor_utils.cc DEPS memcpy lite_full_static framework_proto boost device_context)
 cc_test(test_lite_engine SRCS test_engine.cc DEPS lite_engine protobuf framework_proto glog gtest analysis)
 cc_test(test_lite_tensor_utils SRCS test_tensor_utils.cc DEPS lite_engine lite_tensor_utils)
@@ -18,8 +18,6 @@
 #include "paddle/fluid/inference/lite/engine.h"

-#include "lite/api/paddle_use_kernels.h"
-#include "lite/api/paddle_use_ops.h"
 #include "lite/api/paddle_use_passes.h"

 namespace paddle {
...
@@ -26,15 +26,14 @@ namespace lite {
 // Just tell by the op_types.
 struct SimpleOpTeller : public Teller {
   SimpleOpTeller() {
-    const std::map<std::string, std::string>& op2path =
-        paddle::lite::GetOp2PathDict();
+    std::vector<std::string> lite_ops = paddle::lite::GetAllOps();
     auto is_non_inst = [](const std::string& op) -> bool {
       const std::vector<std::string> ops = {"feed", "fetch", "while"};
       return std::find(ops.begin(), ops.end(), op) != ops.end();
     };
-    for (const auto& op : op2path) {
-      if (!is_non_inst(op.first)) {
-        ops_.insert(op.first);
+    for (const auto& op : lite_ops) {
+      if (!is_non_inst(op)) {
+        ops_.insert(op);
       }
     }
   }
...
@@ -30,7 +30,7 @@ TEST(LiteEngineOp, GetNativePlace) {
   platform::Place GetNativePlace(const TargetType& type, int id = 0);
   EXPECT_TRUE(platform::is_cpu_place(GetNativePlace(TargetType::kHost)));
   EXPECT_TRUE(platform::is_gpu_place(GetNativePlace(TargetType::kCUDA)));
-  ASSERT_DEATH(GetNativePlace(TargetType::kUnk), "");
+  EXPECT_ANY_THROW(GetNativePlace(TargetType::kUnk));
 }

 TEST(LiteEngineOp, GetLiteTargetType) {
@@ -48,8 +48,8 @@ TEST(LiteEngineOp, GetLitePrecisionType) {
             PrecisionType::kInt8);
   ASSERT_EQ(GetLitePrecisionType(framework::proto::VarType_Type_INT32),
             PrecisionType::kInt32);
-  ASSERT_DEATH(
-      GetLitePrecisionType(framework::proto::VarType_Type_SELECTED_ROWS), "");
+  EXPECT_ANY_THROW(
+      GetLitePrecisionType(framework::proto::VarType_Type_SELECTED_ROWS));
 }

 TEST(LiteEngineOp, GetNativePrecisionType) {
@@ -62,7 +62,7 @@ TEST(LiteEngineOp, GetNativePrecisionType) {
             framework::proto::VarType_Type_INT8);
   ASSERT_EQ(GetNativePrecisionType(PrecisionType::kInt32),
             framework::proto::VarType_Type_INT32);
-  ASSERT_DEATH(GetNativePrecisionType(PrecisionType::kUnk), "");
+  EXPECT_ANY_THROW(GetNativePrecisionType(PrecisionType::kUnk));
 }

 TEST(LiteEngineOp, GetNativeLayoutType) {
@@ -70,7 +70,7 @@ TEST(LiteEngineOp, GetNativeLayoutType) {
   framework::DataLayout GetNativeLayoutType(const DataLayoutType& type);
   ASSERT_EQ(GetNativeLayoutType(DataLayoutType::kNCHW),
             framework::DataLayout::kNCHW);
-  ASSERT_DEATH(GetNativeLayoutType(DataLayoutType::kNHWC), "");
+  EXPECT_ANY_THROW(GetNativeLayoutType(DataLayoutType::kNHWC));
 }

 void test_tensor_copy(const platform::DeviceContext& ctx) {
...
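The ASSERT_DEATH to EXPECT_ANY_THROW swaps track a change in error reporting:
these helpers now raise C++ exceptions (PADDLE_ENFORCE-style errors) instead of
aborting the process. A minimal gtest sketch contrasting what each macro
expects (MustBePositive is a made-up helper for illustration):

#include <stdexcept>
#include "gtest/gtest.h"

int MustBePositive(int v) {
  if (v <= 0) throw std::invalid_argument("v must be > 0");
  return v;
}

TEST(MacroContrast, ThrowInsteadOfDeath) {
  // EXPECT_ANY_THROW passes when the statement throws any C++ exception;
  // ASSERT_DEATH instead requires the statement to crash a forked process.
  EXPECT_ANY_THROW(MustBePositive(0));
  EXPECT_EQ(MustBePositive(3), 3);
}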
@@ -115,7 +115,18 @@ inline void TransposeQKV(const int batch, const int seq_len,
                          const half *input, half *output, cudaStream_t stream) {
   int scratch_size = batch * head_num * seq_len * seq_len;
   const dim3 grid(seq_len, batch, 3);
-  if (head_size % 2 == 0 && scratch_size % 2 == 0) {
+  if (head_size % 8 == 0 && scratch_size % 8 == 0) {
+    int h = head_size / 8;
+    const int4 *input4 = reinterpret_cast<const int4 *>(input);
+    int4 *output4 = reinterpret_cast<int4 *>(output);
+    dim3 block(h, head_num, 1);
+    // limit h * head_num to max block size(1024).
+    PADDLE_ENFORCE_LE(h * head_num, 1024,
+                      platform::errors::InvalidArgument(
+                          "head_num (%d) * head_size (%d) should <= %d",
+                          head_num, head_size, 1024 * 8));
+    TransposeQkvKernel<int4><<<grid, block, 0, stream>>>(h, input4, output4);
+  } else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
     const int h = head_size / 2;
     const half2 *input2 = reinterpret_cast<const half2 *>(input);
     half2 *output2 = reinterpret_cast<half2 *>(output);
@@ -167,7 +178,7 @@ nvinfer1::DimsExprs QkvToContextPluginDynamic::getOutputDimensions(
   ret.nbDims = 5;
   ret.d[0] = inputs[0].d[0];
   ret.d[1] = inputs[0].d[1];
-  ret.d[2] = expr_builder.constant(hidden_);
+  ret.d[2] = expr_builder.constant(head_size_ * head_number_);
   ret.d[3] = expr_builder.constant(1);
   ret.d[4] = expr_builder.constant(1);
   return ret;
...
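The new fast path works because eight fp16 values occupy exactly one int4
(4 x 4 bytes = 16 bytes), so each thread can move them with a single 128-bit
load/store instead of four half2 transactions. A standalone host-side check of
that size arithmetic (half_t and int4_t are stand-ins, since <cuda_fp16.h> and
the CUDA vector types are not assumed here):

#include <cstdint>
#include <iostream>

struct half_t { uint16_t bits; };       // stand-in for CUDA __half (2 bytes)
struct int4_t { int32_t x, y, z, w; };  // stand-in for CUDA int4 (16 bytes)

int main() {
  static_assert(sizeof(half_t) * 8 == sizeof(int4_t),
                "eight halves pack into one int4");
  std::cout << "one int4 moves " << sizeof(int4_t) / sizeof(half_t)
            << " half values\n";  // prints 8
  return 0;
}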
@@ -20,7 +20,7 @@ function(download_int8_data install_dir data_file)
   endif()
 endfunction()

-function(download_qat_data install_dir data_file)
+function(download_quant_data install_dir data_file)
   if (NOT EXISTS ${install_dir}/${data_file})
     inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file})
   endif()
@@ -85,7 +85,7 @@ function(inference_analysis_api_test_with_fake_data_run TARGET_NAME test_binary
       --disable_mkldnn_fc=${disable_fc})
 endfunction()

-function(inference_analysis_api_qat_test_run TARGET_NAME test_binary fp32_model_dir int8_model_dir data_path)
+function(inference_analysis_api_quant_test_run TARGET_NAME test_binary fp32_model_dir int8_model_dir data_path)
   inference_analysis_test_run(${TARGET_NAME}
     COMMAND ${test_binary}
     ARGS --fp32_model=${fp32_model_dir}
@@ -249,7 +249,7 @@ if(WITH_MKLDNN)
   ## Image classification models
   # ImageNet small dataset
-  # May be already downloaded for INT8 QAT unit tests
+  # It may be already downloaded for Quant & INT8 unit tests
   set(IMAGENET_DATA_ARCHIVE "imagenet_val_100_tail.tar.gz")
   set(IMAGENET_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/imagenet")
   set(IMAGENET_DATA_PATH "${IMAGENET_DATA_DIR}/data.bin")
@@ -315,21 +315,21 @@ if(WITH_MKLDNN)
   download_int8_data(${INT8_MOBILENET_SSD_MODEL_DIR} "mobilenet_ssd_int8_model.tar.gz" )
   inference_analysis_api_object_dection_int8_test_run(test_analyzer_int8_mobilenet_ssd ${INT8_OBJ_DETECT_TEST_APP} ${INT8_MOBILENET_SSD_MODEL_DIR} ${PASCALVOC_DATA_PATH})

-  ### optimized FP32 vs. QAT INT8 tests
-  set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat")
-  set(QAT_IMG_CLASS_TEST_APP "test_analyzer_qat_image_classification")
-  set(QAT_IMG_CLASS_TEST_APP_SRC "analyzer_qat_image_classification_tester.cc")
+  ### optimized FP32 vs. Quant INT8 tests
+  set(QUANT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/quant")
+  set(QUANT_IMG_CLASS_TEST_APP "test_analyzer_quant_image_classification")
+  set(QUANT_IMG_CLASS_TEST_APP_SRC "analyzer_quant_image_classification_tester.cc")

   # build test binary to be used in subsequent tests
-  inference_analysis_api_test_build(${QAT_IMG_CLASS_TEST_APP} ${QAT_IMG_CLASS_TEST_APP_SRC})
+  inference_analysis_api_test_build(${QUANT_IMG_CLASS_TEST_APP} ${QUANT_IMG_CLASS_TEST_APP_SRC})

-  # MobileNet FP32 vs. QAT INT8
-  # The FP32 model should already be downloaded for slim QAT unit tests
-  set(QAT2_MobileNet_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf")
-  set(QAT2_INT8_MobileNet_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf_int8")
-  download_qat_data(${QAT2_INT8_MobileNet_MODEL_DIR} "MobileNet_qat_perf_int8.tar.gz")
-  inference_analysis_api_qat_test_run(test_analyzer_qat_performance_benchmark ${QAT_IMG_CLASS_TEST_APP} ${QAT2_MobileNet_MODEL_DIR}/MobileNet_qat_perf/float ${QAT2_INT8_MobileNet_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH})
+  # MobileNetV1 FP32 vs. Quant INT8
+  # The FP32 model should already be downloaded for slim Quant unit tests
+  set(QUANT2_MobileNetV1_MODEL_DIR "${QUANT_DATA_DIR}/MobileNetV1_quant2")
+  set(QUANT2_INT8_MobileNetV1_MODEL_DIR "${QUANT_DATA_DIR}/MobileNetV1_quant2_int8")
+  download_quant_data(${QUANT2_INT8_MobileNetV1_MODEL_DIR} "MobileNet_qat_perf_int8.tar.gz")
+  inference_analysis_api_quant_test_run(test_analyzer_quant_performance_benchmark ${QUANT_IMG_CLASS_TEST_APP} ${QUANT2_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf/float ${QUANT2_INT8_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH})

   ### Other tests
...
@@ -108,7 +108,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs,
   }
 }

-TEST(Analyzer_qat_image_classification, quantization) {
+TEST(Analyzer_quant_image_classification, quantization) {
   AnalysisConfig fp32_cfg;
   SetConfig(&fp32_cfg, FLAGS_fp32_model);
...
@@ -47,8 +47,8 @@ int test_main(const AnalysisConfig& config, Barrier* barrier = nullptr) {
   std::vector<PaddleTensor> outputs;
   predictor->Run(inputs, &outputs);
   const std::vector<float> truth_values = {
-      -0.00621776, -0.00620937, 0.00990623,  -0.0039817, -0.00074315,
-      0.61229795,  -0.00491806, -0.00068755, 0.18409646, 0.30090684};
+      -0.00621776f, -0.00620937f, 0.00990623f,  -0.0039817f, -0.00074315f,
+      0.61229795f,  -0.00491806f, -0.00068755f, 0.18409646f, 0.30090684f};
   const size_t expected_size = 1;
   EXPECT_EQ(outputs.size(), expected_size);
   float* data_o = static_cast<float*>(outputs[0].data.data());
...
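The f suffixes here and in the next file are not cosmetic: a plain decimal
literal is a double, and list-initializing a std::vector<float> from doubles is
a narrowing conversion, which is ill-formed in braced initialization unless the
value converts exactly (values like -0.00621776 do not). A minimal illustration
of the rule:

#include <vector>

int main() {
  std::vector<float> ok = {-0.00621776f, 0.61229795f};  // float literals: fine
  // std::vector<float> bad = {-0.00621776, 0.61229795};
  // ^ double literals: narrowing inside braces; compilers reject or warn.
  (void)ok;
  return 0;
}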
@@ -49,15 +49,17 @@ TEST(AnalysisPredictor, use_gpu) {
   ASSERT_TRUE(predictor->Run(inputs, &outputs));

   const std::vector<float> truth_values = {
-      127.780396, 738.16656,  1013.2264,  -438.17206, 366.4022,   927.66187,
-      736.2241,   -633.68567, -329.92737, -430.15637, -633.0639,  -146.54858,
-      -1324.2804, -1349.3661, -242.67671, 117.44864,  -801.7251,  -391.51495,
-      -404.8202,  454.16132,  515.48206,  -133.03114, 69.293076,  590.09753,
-      -1434.6917, -1070.8903, 307.0744,   400.52573,  -316.12177, -587.1265,
-      -161.05742, 800.3663,   -96.47157,  748.708,    868.17645,  -447.9403,
-      112.73656,  1127.1992,  47.43518,   677.7219,   593.1881,   -336.4011,
-      551.3634,   397.82474,  78.39835,   -715.4006,  405.96988,  404.25684,
-      246.01978,  -8.430191,  131.36617,  -648.0528};
+      127.780396f, 738.16656f,  1013.2264f,  -438.17206f, 366.4022f,
+      927.66187f,  736.2241f,   -633.68567f, -329.92737f, -430.15637f,
+      -633.0639f,  -146.54858f, -1324.2804f, -1349.3661f, -242.67671f,
+      117.44864f,  -801.7251f,  -391.51495f, -404.8202f,  454.16132f,
+      515.48206f,  -133.03114f, 69.293076f,  590.09753f,  -1434.6917f,
+      -1070.8903f, 307.0744f,   400.52573f,  -316.12177f, -587.1265f,
+      -161.05742f, 800.3663f,   -96.47157f,  748.708f,    868.17645f,
+      -447.9403f,  112.73656f,  1127.1992f,  47.43518f,   677.7219f,
+      593.1881f,   -336.4011f,  551.3634f,   397.82474f,  78.39835f,
+      -715.4006f,  405.96988f,  404.25684f,  246.01978f,  -8.430191f,
+      131.36617f,  -648.0528f};
   const size_t expected_size = 1;
   EXPECT_EQ(outputs.size(), expected_size);
...
@@ -43,6 +43,7 @@ TEST(AnalysisPredictor, use_gpu) {
   std::vector<PaddleTensor> outputs;
   for (auto& input : inputs_all) {
     ASSERT_TRUE(predictor->Run(input, &outputs));
+    predictor->ClearIntermediateTensor();
   }
 }
...
@@ -91,7 +91,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_ten
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse)
-set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper)
+set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost)
 if (WITH_GPU)
   set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor)
 endif()
...
[83 more file diffs are collapsed in the source view; their contents are not shown.]