提交 d1d6c268 编写于 作者: S sandyhouse

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_timeline

...@@ -168,6 +168,9 @@ if(WITH_BRPC_RDMA) ...@@ -168,6 +168,9 @@ if(WITH_BRPC_RDMA)
endif() endif()
endif() endif()
# lite subgraph compilation depends on CUDNN_ROOT,
# so include(cudnn) needs to be in front of include(third_party/lite)
include(cudnn) # set cudnn libraries, must before configure
include(third_party) # download, build, install third_party include(third_party) # download, build, install third_party
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
...@@ -187,7 +190,6 @@ if(NOT WIN32) ...@@ -187,7 +190,6 @@ if(NOT WIN32)
endif() endif()
include(flags) # set paddle compile flags include(flags) # set paddle compile flags
include(cudnn) # set cudnn libraries, must before configure
if(WITH_GPU) if(WITH_GPU)
include(cuda) include(cuda)
...@@ -216,6 +218,9 @@ endif(WITH_AMD_GPU) ...@@ -216,6 +218,9 @@ endif(WITH_AMD_GPU)
if(WITH_ARM) if(WITH_ARM)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(WITH_XBYAK OFF CACHE STRING "Disable XBYAK when compiling WITH_ARM=ON" FORCE)
set(WITH_MKL OFF CACHE STRING "Disable MKL when compiling WITH_ARM=ON." FORCE)
set(WITH_GPU OFF CACHE STRING "Disable GPU when compiling WITH_ARM=ON." FORCE)
add_definitions(-DPADDLE_WITH_ARM) add_definitions(-DPADDLE_WITH_ARM)
endif() endif()
......
 
# PaddlePaddle <p align="center">
<img align="center" src="doc/imgs/logo.png" width=1600>
</p>
--------------------------------------------------------------------------------
English | [简体中文](./README_cn.md) English | [简体中文](./README_cn.md)
...@@ -29,7 +33,7 @@ pip install paddlepaddle ...@@ -29,7 +33,7 @@ pip install paddlepaddle
# Linux GPU cuda10cudnn7 # Linux GPU cuda10cudnn7
pip install paddlepaddle-gpu pip install paddlepaddle-gpu
# Linux GPU cuda9cudnn7 # Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu==1.8.2.post97 pip install paddlepaddle-gpu==1.8.3.post97
``` ```
It is recommended to read [this doc](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/install/index_en.html) on our website. It is recommended to read [this doc](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/install/index_en.html) on our website.
......
 
# PaddlePaddle <p align="center">
<img align="center" src="doc/imgs/logo.png" width=1600>
</p>
--------------------------------------------------------------------------------
[English](./README.md) | 简体中文 [English](./README.md) | 简体中文
...@@ -26,7 +30,7 @@ pip install paddlepaddle ...@@ -26,7 +30,7 @@ pip install paddlepaddle
# Linux GPU cuda10cudnn7 # Linux GPU cuda10cudnn7
pip install paddlepaddle-gpu pip install paddlepaddle-gpu
# Linux GPU cuda9cudnn7 # Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu==1.8.2.post97 pip install paddlepaddle-gpu==1.8.3.post97
``` ```
更多安装信息详见官网 [安装说明](http://www.paddlepaddle.org.cn/documentation/docs/zh/1.8/beginners_guide/install/index_cn.html) 更多安装信息详见官网 [安装说明](http://www.paddlepaddle.org.cn/documentation/docs/zh/1.8/beginners_guide/install/index_cn.html)
......
...@@ -18,6 +18,15 @@ if(NOT LINUX OR NOT WITH_MKL) ...@@ -18,6 +18,15 @@ if(NOT LINUX OR NOT WITH_MKL)
return() return()
endif() endif()
if(XPU_SDK_ROOT)
set(LITE_WITH_XPU ON)
include_directories("${XPU_SDK_ROOT}/XTDK/include")
include_directories("${XPU_SDK_ROOT}/XTCL/include")
add_definitions(-DPADDLE_WITH_XPU)
LINK_DIRECTORIES("${XPU_SDK_ROOT}/XTDK/shlib/")
LINK_DIRECTORIES("${XPU_SDK_ROOT}/XTDK/runtime/shlib/")
endif()
if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
include(ExternalProject) include(ExternalProject)
set(LITE_PROJECT extern_lite) set(LITE_PROJECT extern_lite)
...@@ -25,7 +34,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) ...@@ -25,7 +34,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite) set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite)
if(NOT LITE_GIT_TAG) if(NOT LITE_GIT_TAG)
set(LITE_GIT_TAG ab8af5c4b4dc5b40217633e0aa436315912d7b53) set(LITE_GIT_TAG 42ab4d559f6659edfc35040fb30fdcec3dc3f8aa)
endif() endif()
if(NOT CUDA_ARCH_NAME) if(NOT CUDA_ARCH_NAME)
...@@ -47,6 +56,8 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) ...@@ -47,6 +56,8 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
-DCUDNN_ROOT=${CUDNN_ROOT} -DCUDNN_ROOT=${CUDNN_ROOT}
-DLITE_WITH_STATIC_CUDA=OFF -DLITE_WITH_STATIC_CUDA=OFF
-DCUDA_ARCH_NAME=${CUDA_ARCH_NAME} -DCUDA_ARCH_NAME=${CUDA_ARCH_NAME}
-DLITE_WITH_XPU=${LITE_WITH_XPU}
-DXPU_SDK_ROOT=${XPU_SDK_ROOT}
-DLITE_WITH_ARM=OFF) -DLITE_WITH_ARM=OFF)
ExternalProject_Add( ExternalProject_Add(
...@@ -83,7 +94,7 @@ message(STATUS "Paddle-lite SOURCE_DIR: ${LITE_SOURCE_DIR}") ...@@ -83,7 +94,7 @@ message(STATUS "Paddle-lite SOURCE_DIR: ${LITE_SOURCE_DIR}")
include_directories(${LITE_SOURCE_DIR}) include_directories(${LITE_SOURCE_DIR})
include_directories(${LITE_BINARY_DIR}) include_directories(${LITE_BINARY_DIR})
function(external_lite_static_libs alias path) function(external_lite_libs alias path)
add_library(${alias} SHARED IMPORTED GLOBAL) add_library(${alias} SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ${alias} PROPERTY IMPORTED_LOCATION SET_PROPERTY(TARGET ${alias} PROPERTY IMPORTED_LOCATION
${path}) ${path})
...@@ -92,7 +103,8 @@ function(external_lite_static_libs alias path) ...@@ -92,7 +103,8 @@ function(external_lite_static_libs alias path)
endif() endif()
endfunction() endfunction()
external_lite_static_libs(lite_full_static ${LITE_BINARY_DIR}/inference_lite_lib/cxx/lib/libpaddle_full_api_shared.so) external_lite_libs(lite_full_static ${LITE_BINARY_DIR}/inference_lite_lib/cxx/lib/libpaddle_full_api_shared.so)
set(LITE_SHARED_LIB ${LITE_BINARY_DIR}/inference_lite_lib/cxx/lib/libpaddle_full_api_shared.so)
add_definitions(-DPADDLE_WITH_LITE) add_definitions(-DPADDLE_WITH_LITE)
add_definitions(-DLITE_WITH_LOG) add_definitions(-DLITE_WITH_LOG)
...@@ -36,28 +36,12 @@ SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/${LIBDIR ...@@ -36,28 +36,12 @@ SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/${LIBDIR
INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) # For MKLDNN code to include internal headers. INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) # For MKLDNN code to include internal headers.
IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
SET(MKLDNN_DEPENDS ${MKLML_PROJECT})
MESSAGE(STATUS "Build MKLDNN with MKLML ${MKLML_ROOT}")
ELSE()
MESSAGE(STATUS "Build MKLDNN without MKLML")
ENDIF()
IF(NOT WIN32) IF(NOT WIN32)
SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds") SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds")
SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value") SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value")
SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}") SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}")
SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}") SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}")
IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
# Force libmkldnn.so to link libiomp5.so (provided by intel mkl) instead of libgomp.so (provided by gcc),
# since core_avx.so links libiomp5.so
set(MKLDNN_SHARED_LINKER_FLAG "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--no-as-needed -L${MKLML_LIB_DIR} -liomp5")
set(FORBID "-fopenmp")
ELSE()
set(MKLDNN_SHARED_LINKER_FLAG "${CMAKE_SHARED_LINKER_FLAGS}")
set(FORBID "")
ENDIF()
ELSE() ELSE()
SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} /EHsc") SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} /EHsc")
ENDIF(NOT WIN32) ENDIF(NOT WIN32)
...@@ -91,8 +75,6 @@ ExternalProject_Add( ...@@ -91,8 +75,6 @@ ExternalProject_Add(
-DCMAKE_C_FLAGS=${MKLDNN_CFLAG} -DCMAKE_C_FLAGS=${MKLDNN_CFLAG}
-DCMAKE_CXX_FLAGS=${MKLDNN_CXXFLAG} -DCMAKE_CXX_FLAGS=${MKLDNN_CXXFLAG}
-DDNNL_BUILD_TESTS=OFF -DDNNL_BUILD_EXAMPLES=OFF -DDNNL_BUILD_TESTS=OFF -DDNNL_BUILD_EXAMPLES=OFF
-DCMAKE_SHARED_LINKER_FLAGS=${MKLDNN_SHARED_LINKER_FLAG}
-DCMAKE_CXX_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS=${FORBID}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR} CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR}
) )
if(WIN32) if(WIN32)
......
...@@ -20,6 +20,8 @@ SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas) ...@@ -20,6 +20,8 @@ SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas)
SET(CBLAS_REPOSITORY https://github.com/xianyi/OpenBLAS.git) SET(CBLAS_REPOSITORY https://github.com/xianyi/OpenBLAS.git)
SET(CBLAS_TAG v0.3.7) SET(CBLAS_TAG v0.3.7)
IF(WITH_ARM) IF(WITH_ARM)
# Under the FT2000 architecture, the calculation result of blas.sgemm in openblas 0.3+ is wrong,
# so version 0.2 is used by default.
SET(CBLAS_TAG v0.2.18) SET(CBLAS_TAG v0.2.18)
ENDIF() ENDIF()
cache_third_party(extern_openblas cache_third_party(extern_openblas
......
...@@ -145,9 +145,9 @@ if (NOT "${PROTOBUF_ROOT}" STREQUAL "") ...@@ -145,9 +145,9 @@ if (NOT "${PROTOBUF_ROOT}" STREQUAL "")
find_program(PROTOBUF_PROTOC_EXECUTABLE protoc PATHS ${PROTOBUF_ROOT}/bin NO_DEFAULT_PATH) find_program(PROTOBUF_PROTOC_EXECUTABLE protoc PATHS ${PROTOBUF_ROOT}/bin NO_DEFAULT_PATH)
if (PROTOBUF_INCLUDE_DIR AND PROTOBUF_LIBRARY AND PROTOBUF_LITE_LIBRARY AND PROTOBUF_PROTOC_LIBRARY AND PROTOBUF_PROTOC_EXECUTABLE) if (PROTOBUF_INCLUDE_DIR AND PROTOBUF_LIBRARY AND PROTOBUF_LITE_LIBRARY AND PROTOBUF_PROTOC_LIBRARY AND PROTOBUF_PROTOC_EXECUTABLE)
SET(PROTOBUF_FOUND true) SET(PROTOBUF_FOUND true)
message(STATUS "Using custom protobuf library in ${PROTOBUF_ROOT}.")
SET_PROTOBUF_VERSION() SET_PROTOBUF_VERSION()
PROMPT_PROTOBUF_LIB() PROMPT_PROTOBUF_LIB()
message(STATUS "Using custom protobuf library in ${PROTOBUF_ROOT}.")
endif() endif()
endif() endif()
......
...@@ -8,6 +8,8 @@ function(CheckCompilerCXX11Flag) ...@@ -8,6 +8,8 @@ function(CheckCompilerCXX11Flag)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
elseif(${CMAKE_CXX_COMPILER_VERSION} VERSION_GREATER 8.2)
message(WARNING "Found GCC ${CMAKE_CXX_COMPILER_VERSION} which is too high, recommended to use GCC 8.2")
endif() endif()
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
......
...@@ -819,20 +819,18 @@ function(brpc_library TARGET_NAME) ...@@ -819,20 +819,18 @@ function(brpc_library TARGET_NAME)
cc_library("${TARGET_NAME}" SRCS "${brpc_library_SRCS}" DEPS "${TARGET_NAME}_proto" "${brpc_library_DEPS}") cc_library("${TARGET_NAME}" SRCS "${brpc_library_SRCS}" DEPS "${TARGET_NAME}_proto" "${brpc_library_DEPS}")
endfunction() endfunction()
# copy_if_different from src_file to dst_file before barrier_target. # copy_if_different from src_file to dst_file At the beginning of the build.
function(copy_if_different src_file dst_file barrier_target) function(copy_if_different src_file dst_file)
# this is a dummy target, should always be run to update ${dst_file} get_filename_component(FILE_NAME ${dst_file} NAME_WE)
add_custom_target(before_${barrier_target} ALL
DEPENDS before_${barrier_target}_custom_command
)
add_dependencies(${barrier_target} before_${barrier_target})
add_custom_command( # this is a dummy target for custom command, should always be run firstly to update ${dst_file}
OUTPUT before_${barrier_target}_custom_command add_custom_target(copy_${FILE_NAME}_command ALL
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${src_file} ${dst_file} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${src_file} ${dst_file}
COMMENT "copy_if_different ${dst_file}" COMMENT "copy_if_different ${dst_file}"
VERBATIM VERBATIM
) )
add_dependencies(extern_glog copy_${FILE_NAME}_command)
endfunction() endfunction()
# create a dummy source file, then create a static library. # create a dummy source file, then create a static library.
......
...@@ -19,9 +19,12 @@ set(FLUID_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_install_dir" CACHE STRING ...@@ -19,9 +19,12 @@ set(FLUID_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_install_dir" CACHE STRING
set(FLUID_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_install_dir" CACHE STRING set(FLUID_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_install_dir" CACHE STRING
"A path setting fluid inference shared and static libraries") "A path setting fluid inference shared and static libraries")
# TODO(zhaolong)
# At present, the size of static lib in Windows exceeds the system limit,
# so the generation of static lib is temporarily turned off.
if(WIN32) if(WIN32)
#todo: remove the option #todo: remove the option
option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON) option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." OFF)
if(NOT PYTHON_EXECUTABLE) if(NOT PYTHON_EXECUTABLE)
FIND_PACKAGE(PythonInterp REQUIRED) FIND_PACKAGE(PythonInterp REQUIRED)
endif() endif()
...@@ -187,21 +190,18 @@ copy(inference_lib_dist ...@@ -187,21 +190,18 @@ copy(inference_lib_dist
SRCS ${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io/crypto/cipher.h SRCS ${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io/crypto/cipher.h
DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include/crypto/) DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include/crypto/)
include_directories(${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io) include_directories(${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io)
# CAPI inference library for only inference # CAPI inference library for only inference
set(FLUID_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_c_install_dir" CACHE STRING set(FLUID_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_c_install_dir" CACHE STRING
"A path setting CAPI fluid inference shared") "A path setting CAPI fluid inference shared")
copy_part_of_thrid_party(inference_lib_dist ${FLUID_INFERENCE_C_INSTALL_DIR}) copy_part_of_thrid_party(inference_lib_dist ${FLUID_INFERENCE_C_INSTALL_DIR})
set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid") set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")
if(WIN32) set(paddle_fluid_c_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/capi/libpaddle_fluid_c.*)
set(paddle_fluid_c_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/capi/${CMAKE_BUILD_TYPE}/paddle_fluid_c.*)
else(WIN32)
set(paddle_fluid_c_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/capi/libpaddle_fluid_c.*)
endif(WIN32)
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${src_dir}/inference/capi/paddle_c_api.h ${paddle_fluid_c_lib} SRCS ${src_dir}/inference/capi/paddle_c_api.h ${paddle_fluid_c_lib}
DSTS ${FLUID_INFERENCE_C_INSTALL_DIR}/paddle/include ${FLUID_INFERENCE_C_INSTALL_DIR}/paddle/lib) DSTS ${FLUID_INFERENCE_C_INSTALL_DIR}/paddle/include ${FLUID_INFERENCE_C_INSTALL_DIR}/paddle/lib)
# fluid library for both train and inference # fluid library for both train and inference
set(fluid_lib_deps inference_lib_dist) set(fluid_lib_deps inference_lib_dist)
......
...@@ -7,14 +7,14 @@ if(WIN32) ...@@ -7,14 +7,14 @@ if(WIN32)
return() return()
endif() endif()
set(NCCL_ROOT "/usr" CACHE PATH "NCCL ROOT")
find_path(NCCL_INCLUDE_DIR nccl.h
PATHS ${NCCL_ROOT} ${NCCL_ROOT}/include ${NCCL_ROOT}/local/include
$ENV{NCCL_ROOT} $ENV{NCCL_ROOT}/include $ENV{NCCL_ROOT}/local/include
NO_DEFAULT_PATH
)
if(WITH_NCCL) if(WITH_NCCL)
set(NCCL_ROOT "/usr" CACHE PATH "NCCL ROOT")
find_path(NCCL_INCLUDE_DIR nccl.h
PATHS ${NCCL_ROOT} ${NCCL_ROOT}/include ${NCCL_ROOT}/local/include
$ENV{NCCL_ROOT} $ENV{NCCL_ROOT}/include $ENV{NCCL_ROOT}/local/include
NO_DEFAULT_PATH
)
file(READ ${NCCL_INCLUDE_DIR}/nccl.h NCCL_VERSION_FILE_CONTENTS) file(READ ${NCCL_INCLUDE_DIR}/nccl.h NCCL_VERSION_FILE_CONTENTS)
string(REGEX MATCH "define NCCL_VERSION_CODE +([0-9]+)" string(REGEX MATCH "define NCCL_VERSION_CODE +([0-9]+)"
......
...@@ -114,7 +114,7 @@ function(op_library TARGET) ...@@ -114,7 +114,7 @@ function(op_library TARGET)
endif() endif()
# Define operators that don't need pybind here. # Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_reduce_op" "compare_op" "logical_op" "nccl_op" foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
......
...@@ -63,7 +63,8 @@ class Array { ...@@ -63,7 +63,8 @@ class Array {
HOSTDEVICE inline const T &at(size_t i) const { HOSTDEVICE inline const T &at(size_t i) const {
#ifndef __CUDA_ARCH__ #ifndef __CUDA_ARCH__
PADDLE_ENFORCE_LT(i, N, "Array index out of bounds"); PADDLE_ENFORCE_LT(
i, N, platform::errors::OutOfRange("Array index out of bounds."));
#endif #endif
return (*this)[i]; return (*this)[i];
} }
...@@ -106,7 +107,7 @@ class Array<T, 0> { ...@@ -106,7 +107,7 @@ class Array<T, 0> {
static T obj(); static T obj();
return obj; return obj;
#else #else
PADDLE_THROW("Array<T, 0> has no element"); PADDLE_THROW(platform::errors::Unavailable("Array<T, 0> has no element."));
#endif #endif
} }
...@@ -115,7 +116,7 @@ class Array<T, 0> { ...@@ -115,7 +116,7 @@ class Array<T, 0> {
static const T obj(); static const T obj();
return obj; return obj;
#else #else
PADDLE_THROW("Array<T, 0> has no element"); PADDLE_THROW(platform::errors::Unavailable("Array<T, 0> has no element."));
#endif #endif
} }
......
...@@ -77,11 +77,13 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -77,11 +77,13 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto var_name : fetch_var_names) { for (auto var_name : fetch_var_names) {
auto var_desc = block.FindVar(var_name); auto var_desc = block.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
var_desc, platform::errors::NotFound("%s is not found.", var_name)); var_desc, platform::errors::NotFound(
"Variable %s is not found in main program.", var_name));
auto shapes = var_desc->GetShape(); auto shapes = var_desc->GetShape();
PADDLE_ENFORCE(shapes[shapes.size() - 1] == 1, PADDLE_ENFORCE_EQ(shapes[shapes.size() - 1], 1,
"var %s: Fetched var has wrong shape, " platform::errors::InvalidArgument(
"only variables with the last dimension size 1 supported", "Fetched variable %s has wrong shape, "
"only variables whose last dimension is 1 are supported",
var_name); var_name);
} }
...@@ -95,7 +97,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -95,7 +97,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
actual_thread_num_ = thread_num; actual_thread_num_ = thread_num;
int file_cnt = filelist.size(); int file_cnt = filelist.size();
PADDLE_ENFORCE_GT(file_cnt, 0, PADDLE_ENFORCE_GT(file_cnt, 0,
platform::errors::NotFound("Input file list is empty")); platform::errors::NotFound("Input file list is empty."));
if (actual_thread_num_ > file_cnt) { if (actual_thread_num_ > file_cnt) {
VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
......
...@@ -72,7 +72,8 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) { ...@@ -72,7 +72,8 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
return val; return val;
} }
default: default:
PADDLE_THROW("Unsupport attr type %d", attr_desc.type()); PADDLE_THROW(platform::errors::Unavailable("Unsupport attribute type %d.",
attr_desc.type()));
} }
return boost::blank(); return boost::blank();
} }
......
...@@ -37,9 +37,10 @@ struct ExtractAttribute { ...@@ -37,9 +37,10 @@ struct ExtractAttribute {
try { try {
attr_value = &boost::get<T>(attr); attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) { } catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s", PADDLE_THROW(platform::errors::InvalidArgument(
attr_name_, paddle::platform::demangle(typeid(T).name()), "Cannot get attribute (%s) by type %s, its type is %s.", attr_name_,
paddle::platform::demangle(attr.type().name())); paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name())));
} }
return attr_value; return attr_value;
} }
...@@ -70,8 +71,9 @@ struct ExtractAttribute<bool> { ...@@ -70,8 +71,9 @@ struct ExtractAttribute<bool> {
try { try {
attr_value = &boost::get<bool>(attr); attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) { } catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s", PADDLE_THROW(platform::errors::InvalidArgument(
attr_name_, paddle::platform::demangle(attr.type().name())); "Cannot get attribute (%s) by type bool, its type is %s.", attr_name_,
paddle::platform::demangle(attr.type().name())));
} }
return attr_value; return attr_value;
} }
...@@ -96,8 +98,9 @@ struct ExtractAttribute<int64_t> { ...@@ -96,8 +98,9 @@ struct ExtractAttribute<int64_t> {
try { try {
attr_value = &boost::get<int64_t>(attr); attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) { } catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s", PADDLE_THROW(platform::errors::InvalidArgument(
attr_name_, paddle::platform::demangle(attr.type().name())); "Cannot get attribute (%s) by type int64_t, its type is %s.",
attr_name_, paddle::platform::demangle(attr.type().name())));
} }
return attr_value; return attr_value;
} }
...@@ -124,8 +127,10 @@ struct ExtractAttribute<std::vector<int64_t>> { ...@@ -124,8 +127,10 @@ struct ExtractAttribute<std::vector<int64_t>> {
try { try {
attr_value = &boost::get<std::vector<int64_t>>(attr); attr_value = &boost::get<std::vector<int64_t>>(attr);
} catch (boost::bad_get& bad_get) { } catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s", PADDLE_THROW(platform::errors::InvalidArgument(
attr_name_, paddle::platform::demangle(attr.type().name())); "Cannot get attribute (%s) by type std::vector<int64_t>, its type is "
"%s.",
attr_name_, paddle::platform::demangle(attr.type().name())));
} }
return attr_value; return attr_value;
} }
...@@ -150,8 +155,9 @@ struct ExtractAttribute<float> { ...@@ -150,8 +155,9 @@ struct ExtractAttribute<float> {
try { try {
attr_value = &boost::get<float>(attr); attr_value = &boost::get<float>(attr);
} catch (boost::bad_get& bad_get) { } catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type float, its type is %s", PADDLE_THROW(platform::errors::InvalidArgument(
attr_name_, paddle::platform::demangle(attr.type().name())); "Cannot get attribute (%s) by type float, its type is %s.",
attr_name_, paddle::platform::demangle(attr.type().name())));
} }
return attr_value; return attr_value;
} }
...@@ -173,8 +179,9 @@ class AttrReader { ...@@ -173,8 +179,9 @@ class AttrReader {
template <typename T> template <typename T>
inline const T& Get(const std::string& name) const { inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", PADDLE_ENFORCE_NE(attrs_.count(name), 0,
name); platform::errors::NotFound(
"Attribute (%s) should be in AttributeMap.", name));
Attribute& attr = const_cast<Attribute&>(attrs_.at(name)); Attribute& attr = const_cast<Attribute&>(attrs_.at(name));
ExtractAttribute<T> extract_attr(name); ExtractAttribute<T> extract_attr(name);
...@@ -192,8 +199,10 @@ class GreaterThanChecker { ...@@ -192,8 +199,10 @@ class GreaterThanChecker {
public: public:
explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(const T& value) const { void operator()(const T& value) const {
PADDLE_ENFORCE_GT(value, lower_bound_, PADDLE_ENFORCE_GT(
platform::errors::OutOfRange("larger_than check fails.")); value, lower_bound_,
platform::errors::OutOfRange(
"Check for attribute value greater than a certain value failed."));
} }
private: private:
...@@ -205,7 +214,10 @@ class EqualGreaterThanChecker { ...@@ -205,7 +214,10 @@ class EqualGreaterThanChecker {
public: public:
explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(const T& value) const { void operator()(const T& value) const {
PADDLE_ENFORCE_GE(value, lower_bound_, "equal_larger_than check fails."); PADDLE_ENFORCE_GE(
value, lower_bound_,
platform::errors::OutOfRange("Check for attribute valur equal or "
"greater than a certain value failed."));
} }
private: private:
...@@ -231,9 +243,10 @@ class EnumInContainer { ...@@ -231,9 +243,10 @@ class EnumInContainer {
public: public:
explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {} explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {}
void operator()(const T& val) const { void operator()(const T& val) const {
PADDLE_ENFORCE(container_.find(val) != container_.end(), PADDLE_ENFORCE_NE(
"Value %s is not in enum container %s", val, container_.find(val), container_.end(),
ContainerDebugString()); platform::errors::NotFound("Value %s is not in enum container %s.", val,
ContainerDebugString()));
} }
private: private:
...@@ -284,8 +297,11 @@ class TypedAttrChecker { ...@@ -284,8 +297,11 @@ class TypedAttrChecker {
// we can add more common limits, like LessThan(), Between()... // we can add more common limits, like LessThan(), Between()...
TypedAttrChecker& SetDefault(const T& default_value) { TypedAttrChecker& SetDefault(const T& default_value) {
PADDLE_ENFORCE(default_value_setter_.empty(), PADDLE_ENFORCE_EQ(
"%s can't have more than one default value!", attr_name_); default_value_setter_.empty(), true,
platform::errors::AlreadyExists(
"Attribute (%s) has a default value and cannot be set repeatedly.",
attr_name_));
default_value_setter_.push_back(DefaultValueSetter<T>(default_value)); default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
return *this; return *this;
} }
...@@ -308,8 +324,10 @@ class TypedAttrChecker { ...@@ -308,8 +324,10 @@ class TypedAttrChecker {
auto it = attr_map->find(attr_name_); auto it = attr_map->find(attr_name_);
if (it == attr_map->end()) { if (it == attr_map->end()) {
// user do not set this attr // user do not set this attr
PADDLE_ENFORCE(!default_value_setter_.empty(), PADDLE_ENFORCE_EQ(
"Attribute '%s' is required!", attr_name_); default_value_setter_.empty(), false,
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element // default_value_setter_ has no more than one element
attr_map->emplace(attr_name_, default_value_setter_[0]()); attr_map->emplace(attr_name_, default_value_setter_[0]());
} }
......
...@@ -23,11 +23,14 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place, ...@@ -23,11 +23,14 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
in.place().which(), dst_place.which(), in.place().which(), dst_place.which(),
"Currently, model parallelism is only supported between CPU and CUDA"); platform::errors::Unavailable("Currently, model parallelism is only "
"supported between CPU and CUDA."));
// NOTE(yy): TransDataDevice should wait for computation of input. // NOTE(yy): TransDataDevice should wait for computation of input.
platform::DeviceContextPool::Instance().Get(in.place())->Wait(); if (!platform::is_cuda_pinned_place(in.place())) {
platform::DeviceContextPool::Instance().Get(dst_place)->Wait(); platform::DeviceContextPool::Instance().Get(in.place())->Wait();
platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
}
// FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and // FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and
// the enforced checkings have been done in GetDeviceContext, so the // the enforced checkings have been done in GetDeviceContext, so the
......
...@@ -133,11 +133,14 @@ bool DataFeed::PickOneFile(std::string* filename) { ...@@ -133,11 +133,14 @@ bool DataFeed::PickOneFile(std::string* filename) {
} }
void DataFeed::CheckInit() { void DataFeed::CheckInit() {
PADDLE_ENFORCE(finish_init_, "Initialization did not succeed."); PADDLE_ENFORCE_EQ(finish_init_, true, platform::errors::PreconditionNotMet(
"DataFeed initialization failed."));
} }
void DataFeed::CheckSetFileList() { void DataFeed::CheckSetFileList() {
PADDLE_ENFORCE(finish_set_filelist_, "Set filelist did not succeed."); PADDLE_ENFORCE_EQ(
finish_set_filelist_, true,
platform::errors::PreconditionNotMet("DataFeed set filelist failed."));
} }
void DataFeed::CheckStart() { void DataFeed::CheckStart() {
...@@ -160,14 +163,18 @@ void DataFeed::CopyToFeedTensor(void* dst, const void* src, size_t size) { ...@@ -160,14 +163,18 @@ void DataFeed::CopyToFeedTensor(void* dst, const void* src, size_t size) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice);
#else #else
PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option"); PADDLE_THROW(platform::errors::Unimplemented(
"Not supported GPU, please compile with option WITH_GPU=ON."));
#endif #endif
} }
} }
template <typename T> template <typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) { void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
PADDLE_ENFORCE(queue_size > 0, "Illegal queue size: %d.", queue_size); PADDLE_ENFORCE_GT(
queue_size, 0,
platform::errors::InvalidArgument(
"Queue size %d is illegal in PrivateQueueDataFeed.", queue_size));
queue_size_ = queue_size; queue_size_ = queue_size;
queue_ = paddle::framework::MakeChannel<T>(); queue_ = paddle::framework::MakeChannel<T>();
queue_->SetCapacity(queue_size); queue_->SetCapacity(queue_size);
...@@ -418,8 +425,10 @@ void MultiSlotDataFeed::Init( ...@@ -418,8 +425,10 @@ void MultiSlotDataFeed::Init(
finish_set_filelist_ = false; finish_set_filelist_ = false;
finish_start_ = false; finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(), PADDLE_ENFORCE_EQ(
"Multi_slot_desc has not been set."); data_feed_desc.has_multi_slot_desc(), true,
platform::errors::PreconditionNotMet(
"Multi_slot_desc has not been set in MultiSlotDataFeed."));
paddle::framework::MultiSlotDesc multi_slot_desc = paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc(); data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size()); SetBatchSize(data_feed_desc.batch_size());
...@@ -668,13 +677,14 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) { ...@@ -668,13 +677,14 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
for (size_t i = 0; i < use_slots_index_.size(); ++i) { for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i]; int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10); int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE( PADDLE_ENFORCE_NE(
num, num, 0,
"The number of ids can not be zero, you need padding " platform::errors::InvalidArgument(
"it in data generator; or if there is something wrong with " "The number of ids can not be zero, you need padding "
"the data, please check if the data contains unresolvable " "it in data generator; or if there is something wrong with "
"characters.\nplease check this error line: %s", "the data, please check if the data contains unresolvable "
str); "characters.\nplease check this error line: %s.",
str));
if (idx != -1) { if (idx != -1) {
(*instance)[idx].Init(all_slots_type_[i]); (*instance)[idx].Init(all_slots_type_[i]);
...@@ -765,8 +775,10 @@ void MultiSlotInMemoryDataFeed::Init( ...@@ -765,8 +775,10 @@ void MultiSlotInMemoryDataFeed::Init(
finish_set_filelist_ = false; finish_set_filelist_ = false;
finish_start_ = false; finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(), PADDLE_ENFORCE_EQ(
"Multi_slot_desc has not been set."); data_feed_desc.has_multi_slot_desc(), true,
platform::errors::PreconditionNotMet(
"Multi_slot_desc has not been set in MultiSlotInMemoryDataFeed."));
paddle::framework::MultiSlotDesc multi_slot_desc = paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc(); data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size()); SetBatchSize(data_feed_desc.batch_size());
...@@ -898,13 +910,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { ...@@ -898,13 +910,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
for (size_t i = 0; i < use_slots_index_.size(); ++i) { for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i]; int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10); int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE( PADDLE_ENFORCE_NE(
num, num, 0,
"The number of ids can not be zero, you need padding " platform::errors::InvalidArgument(
"it in data generator; or if there is something wrong with " "The number of ids can not be zero, you need padding "
"the data, please check if the data contains unresolvable " "it in data generator; or if there is something wrong with "
"characters.\nplease check this error line: %s", "the data, please check if the data contains unresolvable "
str); "characters.\nplease check this error line: %s.",
str));
if (idx != -1) { if (idx != -1) {
if (all_slots_type_[i][0] == 'f') { // float if (all_slots_type_[i][0] == 'f') { // float
for (int j = 0; j < num; ++j) { for (int j = 0; j < num; ++j) {
...@@ -963,13 +976,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) { ...@@ -963,13 +976,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) {
for (size_t i = 0; i < use_slots_index_.size(); ++i) { for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i]; int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10); int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE( PADDLE_ENFORCE_NE(
num, num, 0,
"The number of ids can not be zero, you need padding " platform::errors::InvalidArgument(
"it in data generator; or if there is something wrong with " "The number of ids can not be zero, you need padding "
"the data, please check if the data contains unresolvable " "it in data generator; or if there is something wrong with "
"characters.\nplease check this error line: %s", "the data, please check if the data contains unresolvable "
str); "characters.\nplease check this error line: %s.",
str));
if (idx != -1) { if (idx != -1) {
if (all_slots_type_[i][0] == 'f') { // float if (all_slots_type_[i][0] == 'f') { // float
...@@ -1085,7 +1099,7 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec( ...@@ -1085,7 +1099,7 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
PADDLE_ENFORCE_EQ(slot_offset.size(), 2, PADDLE_ENFORCE_EQ(slot_offset.size(), 2,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"In batch reader, the sparse tensor lod size " "In batch reader, the sparse tensor lod size "
"must be 2, but received %d", "must be 2, but received %d.",
slot_offset.size())); slot_offset.size()));
const auto& max_size = slot_offset[1]; const auto& max_size = slot_offset[1];
tmp_offset.reserve(max_size + 1); tmp_offset.reserve(max_size + 1);
...@@ -1137,10 +1151,13 @@ void PrivateInstantDataFeed<T>::PutToFeedVec() { ...@@ -1137,10 +1151,13 @@ void PrivateInstantDataFeed<T>::PutToFeedVec() {
for (const auto e : use_slots_shape_[i]) { for (const auto e : use_slots_shape_[i]) {
total_dims *= e; total_dims *= e;
} }
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
total_dims == total_instance, total_dims, total_instance,
"The actual data size of slot[%s] doesn't match its declaration", platform::errors::InvalidArgument(
use_slots_[i].c_str()); "The actual data size of slot[%s] doesn't match its declaration. "
"The actual data size of slot is %lld"
", and its declaration is %lld.",
use_slots_[i].c_str(), total_dims, total_instance));
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i])); feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
} }
} }
...@@ -1162,7 +1179,9 @@ int PrivateInstantDataFeed<T>::Next() { ...@@ -1162,7 +1179,9 @@ int PrivateInstantDataFeed<T>::Next() {
return -1; return -1;
} }
PADDLE_ENFORCE(true == ParseOneMiniBatch(), "Fail to parse mini-batch data"); PADDLE_ENFORCE_EQ(
true, ParseOneMiniBatch(),
platform::errors::InvalidArgument("Fail to parse mini-batch data."));
PutToFeedVec(); PutToFeedVec();
return ins_vec_[0].GetBatchSize(); return ins_vec_[0].GetBatchSize();
} }
...@@ -1173,8 +1192,10 @@ void PrivateInstantDataFeed<T>::Init(const DataFeedDesc& data_feed_desc) { ...@@ -1173,8 +1192,10 @@ void PrivateInstantDataFeed<T>::Init(const DataFeedDesc& data_feed_desc) {
finish_set_filelist_ = false; finish_set_filelist_ = false;
finish_start_ = false; finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(), PADDLE_ENFORCE_EQ(
"Multi_slot_desc has not been set."); data_feed_desc.has_multi_slot_desc(), true,
platform::errors::PreconditionNotMet(
"Multi_slot_desc has not been set in PrivateInstantDataFeed."));
paddle::framework::MultiSlotDesc multi_slot_desc = paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc(); data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size()); SetBatchSize(data_feed_desc.batch_size());
...@@ -1217,7 +1238,10 @@ template class PrivateInstantDataFeed<std::vector<MultiSlotType>>; ...@@ -1217,7 +1238,10 @@ template class PrivateInstantDataFeed<std::vector<MultiSlotType>>;
bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) { bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) {
fd_ = open(filename.c_str(), O_RDONLY); fd_ = open(filename.c_str(), O_RDONLY);
PADDLE_ENFORCE(fd_ != -1, "Fail to open file: %s", filename.c_str()); PADDLE_ENFORCE_NE(
fd_, -1, platform::errors::Unavailable(
"Fail to open file: %s in MultiSlotFileInstantDataFeed.",
filename.c_str()));
struct stat sb; struct stat sb;
fstat(fd_, &sb); fstat(fd_, &sb);
...@@ -1225,7 +1249,11 @@ bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) { ...@@ -1225,7 +1249,11 @@ bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) {
buffer_ = buffer_ =
reinterpret_cast<char*>(mmap(NULL, end_, PROT_READ, MAP_PRIVATE, fd_, 0)); reinterpret_cast<char*>(mmap(NULL, end_, PROT_READ, MAP_PRIVATE, fd_, 0));
PADDLE_ENFORCE(buffer_ != MAP_FAILED, strerror(errno)); PADDLE_ENFORCE_NE(
buffer_, MAP_FAILED,
platform::errors::Unavailable(
"Memory map failed when create shared memory, error number is %s.",
strerror(errno)));
offset_ = 0; offset_ = 0;
return true; return true;
...@@ -1257,12 +1285,13 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() { ...@@ -1257,12 +1285,13 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {
char type = all_slots_type_[i][0]; char type = all_slots_type_[i][0];
uint16_t num = *reinterpret_cast<uint16_t*>(buffer_ + offset_); uint16_t num = *reinterpret_cast<uint16_t*>(buffer_ + offset_);
PADDLE_ENFORCE( PADDLE_ENFORCE_NE(
num, num, 0,
"The number of ids can not be zero, you need padding " platform::errors::InvalidArgument(
"it in data generator; or if there is something wrong with " "The number of ids can not be zero, you need padding "
"the data, please check if the data contains unresolvable " "it in data generator; or if there is something wrong with "
"characters."); "the data, please check if the data contains unresolvable "
"characters."));
offset_ += sizeof(uint16_t); offset_ += sizeof(uint16_t);
if (idx != -1) { if (idx != -1) {
...@@ -1304,7 +1333,12 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() { ...@@ -1304,7 +1333,12 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {
} }
PADDLE_ENFORCE(batch_size_ == default_batch_size_ || offset_ == end_, PADDLE_ENFORCE(batch_size_ == default_batch_size_ || offset_ == end_,
"offset_ != end_"); platform::errors::InvalidArgument(
"The batch size is not equal to default batch size, and "
"the offset is not equal to end index. "
"The batch size is %d, default batch size is %d, offset "
"is %d, end index is %d.",
batch_size_, default_batch_size_, offset_, end_));
return true; return true;
} }
#endif #endif
......
...@@ -116,7 +116,8 @@ class DataFeed { ...@@ -116,7 +116,8 @@ class DataFeed {
virtual ~DataFeed() {} virtual ~DataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc) = 0; virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
virtual bool CheckFile(const char* filename) { virtual bool CheckFile(const char* filename) {
PADDLE_THROW("This function(CheckFile) is not implemented."); PADDLE_THROW(platform::errors::Unimplemented(
"This function(CheckFile) is not implemented."));
} }
// Set filelist for DataFeed. // Set filelist for DataFeed.
// Pay attention that it must init all readers before call this function. // Pay attention that it must init all readers before call this function.
...@@ -179,7 +180,8 @@ class DataFeed { ...@@ -179,7 +180,8 @@ class DataFeed {
} }
virtual int GetCurBatchSize() { return batch_size_; } virtual int GetCurBatchSize() { return batch_size_; }
virtual void LoadIntoMemory() { virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented."); PADDLE_THROW(platform::errors::Unimplemented(
"This function(LoadIntoMemory) is not implemented."));
} }
virtual void SetPlace(const paddle::platform::Place& place) { virtual void SetPlace(const paddle::platform::Place& place) {
place_ = place; place_ = place;
...@@ -438,14 +440,23 @@ class MultiSlotType { ...@@ -438,14 +440,23 @@ class MultiSlotType {
private: private:
void CheckType(const std::string& type) const { void CheckType(const std::string& type) const {
PADDLE_ENFORCE((type == "uint64") || (type == "float"), PADDLE_ENFORCE_EQ((type == "uint64" || type == "float"), true,
"There is no this type<%s>.", type); platform::errors::InvalidArgument(
"MultiSlotType error, expect type is uint64 or "
"float, but received type is %s.",
type));
} }
void CheckFloat() const { void CheckFloat() const {
PADDLE_ENFORCE(type_[0] == 'f', "Add %s value to float slot.", type_); PADDLE_ENFORCE_EQ(
type_[0], 'f',
platform::errors::InvalidArgument(
"MultiSlotType error, add %s value to float slot.", type_));
} }
void CheckUint64() const { void CheckUint64() const {
PADDLE_ENFORCE(type_[0] == 'u', "Add %s value to uint64 slot.", type_); PADDLE_ENFORCE_EQ(
type_[0], 'u',
platform::errors::InvalidArgument(
"MultiSlotType error, add %s value to uint64 slot.", type_));
} }
std::vector<float> float_feasign_; std::vector<float> float_feasign_;
std::vector<uint64_t> uint64_feasign_; std::vector<uint64_t> uint64_feasign_;
......
...@@ -34,8 +34,10 @@ paddle::framework::DataFeedDesc load_datafeed_param_from_file( ...@@ -34,8 +34,10 @@ paddle::framework::DataFeedDesc load_datafeed_param_from_file(
const char* filename) { const char* filename) {
paddle::framework::DataFeedDesc data_feed_desc; paddle::framework::DataFeedDesc data_feed_desc;
int file_descriptor = open(filename, O_RDONLY); int file_descriptor = open(filename, O_RDONLY);
PADDLE_ENFORCE_NE(file_descriptor, -1, platform::errors::Unavailable( PADDLE_ENFORCE_NE(
"Cannot open file %s.", filename)); file_descriptor, -1,
platform::errors::Unavailable(
"Cannot open file %s when load datafeed param from file.", filename));
google::protobuf::io::FileInputStream fileInput(file_descriptor); google::protobuf::io::FileInputStream fileInput(file_descriptor);
google::protobuf::TextFormat::Parse(&fileInput, &data_feed_desc); google::protobuf::TextFormat::Parse(&fileInput, &data_feed_desc);
close(file_descriptor); close(file_descriptor);
...@@ -45,8 +47,10 @@ paddle::framework::DataFeedDesc load_datafeed_param_from_file( ...@@ -45,8 +47,10 @@ paddle::framework::DataFeedDesc load_datafeed_param_from_file(
const std::vector<std::string> load_filelist_from_file(const char* filename) { const std::vector<std::string> load_filelist_from_file(const char* filename) {
std::vector<std::string> filelist; std::vector<std::string> filelist;
std::ifstream fin(filename); std::ifstream fin(filename);
PADDLE_ENFORCE_EQ(fin.good(), true, platform::errors::Unavailable( PADDLE_ENFORCE_EQ(
"Cannot open file %s.", filename)); fin.good(), true,
platform::errors::Unavailable(
"Cannot open file %s when load filelist from file.", filename));
std::string line; std::string line;
while (getline(fin, line)) { while (getline(fin, line)) {
filelist.push_back(line); filelist.push_back(line);
...@@ -196,7 +200,8 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set, ...@@ -196,7 +200,8 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
} }
} }
} else { } else {
PADDLE_THROW("Error type in proto file."); PADDLE_THROW(platform::errors::InvalidArgument(
"Error type in proto file."));
} }
} else { // sparse branch } else { // sparse branch
if (slot.type() == "uint64") { if (slot.type() == "uint64") {
...@@ -218,7 +223,8 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set, ...@@ -218,7 +223,8 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
} }
} }
} else { } else {
PADDLE_THROW("Error type in proto file."); PADDLE_THROW(platform::errors::InvalidArgument(
"Error type in proto file."));
} }
} // end sparse branch } // end sparse branch
++index; ++index;
...@@ -272,7 +278,10 @@ void GetElemSetFromFile(std::vector<MultiTypeSet>* file_elem_set, ...@@ -272,7 +278,10 @@ void GetElemSetFromFile(std::vector<MultiTypeSet>* file_elem_set,
file_elem_set->resize(used_slot_num); file_elem_set->resize(used_slot_num);
for (const auto& file : filelist) { for (const auto& file : filelist) {
std::ifstream fin(file.c_str()); std::ifstream fin(file.c_str());
PADDLE_ENFORCE(fin.good(), "Can not open %s.", file.c_str()); PADDLE_ENFORCE_EQ(
fin.good(), true,
platform::errors::Unavailable(
"Can not open %s when get element set from file.", file.c_str()));
while (1) { while (1) {
bool end_flag = false; bool end_flag = false;
int index = 0; int index = 0;
...@@ -298,7 +307,8 @@ void GetElemSetFromFile(std::vector<MultiTypeSet>* file_elem_set, ...@@ -298,7 +307,8 @@ void GetElemSetFromFile(std::vector<MultiTypeSet>* file_elem_set,
} }
} }
} else { } else {
PADDLE_THROW("Error type in proto file."); PADDLE_THROW(
platform::errors::InvalidArgument("Error type in proto file."));
} }
if (slot.is_used()) { if (slot.is_used()) {
++index; ++index;
......
...@@ -45,7 +45,8 @@ inline DataLayout StringToDataLayout(const std::string& str) { ...@@ -45,7 +45,8 @@ inline DataLayout StringToDataLayout(const std::string& str) {
} else if (s == "MKLDNNLAYOUT") { } else if (s == "MKLDNNLAYOUT") {
return DataLayout::kMKLDNN; return DataLayout::kMKLDNN;
} else { } else {
PADDLE_THROW("Unknown storage order string: %s", s); PADDLE_THROW(platform::errors::InvalidArgument(
"Unknown data layout type string: %s.", s));
} }
} }
...@@ -60,7 +61,8 @@ inline std::string DataLayoutToString(const DataLayout& data_layout) { ...@@ -60,7 +61,8 @@ inline std::string DataLayoutToString(const DataLayout& data_layout) {
case DataLayout::kMKLDNN: case DataLayout::kMKLDNN:
return "MKLDNNLAYOUT"; return "MKLDNNLAYOUT";
default: default:
PADDLE_THROW("unknown DataLayout %d", data_layout); PADDLE_THROW(platform::errors::InvalidArgument(
"Unknown Data Layout type %d.", data_layout));
} }
} }
......
...@@ -25,14 +25,17 @@ namespace paddle { ...@@ -25,14 +25,17 @@ namespace paddle {
namespace framework { namespace framework {
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to) { std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to) {
PADDLE_ENFORCE_NE(from, to, PADDLE_ENFORCE_NE(
"layout transform should transform different layout"); from, to,
platform::errors::InvalidArgument(
"Layout transform should transform between different layout."));
if (from == DataLayout::kNCHW && to == DataLayout::kNHWC) { if (from == DataLayout::kNCHW && to == DataLayout::kNHWC) {
return {0, 2, 3, 1}; return {0, 2, 3, 1};
} else if (from == DataLayout::kNHWC && to == DataLayout::kNCHW) { } else if (from == DataLayout::kNHWC && to == DataLayout::kNCHW) {
return {0, 3, 1, 2}; return {0, 3, 1, 2};
} else { } else {
PADDLE_THROW("unsupported transform"); PADDLE_THROW(
platform::errors::InvalidArgument("Unsupported layout transform."));
} }
} }
...@@ -55,7 +58,8 @@ struct CastDataLayout { ...@@ -55,7 +58,8 @@ struct CastDataLayout {
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_); auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
trans4(*context, in_, out_, axis_); trans4(*context, in_, out_, axis_);
} else { } else {
PADDLE_THROW("Unsupport CPU <-> GPU!"); PADDLE_THROW(platform::errors::PreconditionNotMet(
"Unsupported data layout cast from CPU to GPU."));
} }
} }
}; };
...@@ -66,9 +70,14 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var, ...@@ -66,9 +70,14 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
PADDLE_ENFORCE( PADDLE_ENFORCE(
platform::places_are_same_class(kernel_type_for_var.place_, platform::places_are_same_class(kernel_type_for_var.place_,
expected_kernel_type.place_), expected_kernel_type.place_),
"TransDataLayout only support DataLayout transform on same place!"); platform::errors::PreconditionNotMet(
"TransDataLayout only support DataLayout transform on same place."));
PADDLE_ENFORCE(arity(in.dims()) == 4, "Input Arity only support 4!"); PADDLE_ENFORCE_EQ(
arity(in.dims()), 4,
platform::errors::InvalidArgument(
"Input dimension arity only can be 4, the input dimension is %s.",
in.dims()));
auto& pool = platform::DeviceContextPool::Instance(); auto& pool = platform::DeviceContextPool::Instance();
...@@ -108,7 +117,8 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) { ...@@ -108,7 +117,8 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
case mkldnn::memory::data_type::s32: case mkldnn::memory::data_type::s32:
return platform::to_void_cast(tensor.data<int32_t>()); return platform::to_void_cast(tensor.data<int32_t>());
default: default:
PADDLE_THROW("wrong mkldnn type provided"); PADDLE_THROW(
platform::errors::InvalidArgument("Wrong mkldnn type provided."));
} }
} }
...@@ -121,8 +131,9 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, ...@@ -121,8 +131,9 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
PADDLE_ENFORCE( PADDLE_ENFORCE(
in_layout == DataLayout::kMKLDNN && out_layout != DataLayout::kMKLDNN, in_layout == DataLayout::kMKLDNN && out_layout != DataLayout::kMKLDNN,
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to " platform::errors::InvalidArgument(
"non-MKLDNN"); "TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"non-MKLDNN"));
innerTransDataLayoutFromMKLDNN( innerTransDataLayoutFromMKLDNN(
in_layout, in_layout,
...@@ -155,7 +166,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, ...@@ -155,7 +166,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
memory::data_type in_type = ToMKLDNNDataType(in.type()); memory::data_type in_type = ToMKLDNNDataType(in.type());
PADDLE_ENFORCE_NE(in_type, memory::data_type::undef, PADDLE_ENFORCE_NE(in_type, memory::data_type::undef,
"Input tensor type is not supported: %s", in.type()); platform::errors::InvalidArgument(
"Input tensor type (%s) is not supported.",
DataTypeToString(in.type())));
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format()); auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
auto out_format = auto out_format =
......
...@@ -38,8 +38,9 @@ inline MKLDNNMemoryFormat ToMKLDNNFormat(const DataLayout& layout) { ...@@ -38,8 +38,9 @@ inline MKLDNNMemoryFormat ToMKLDNNFormat(const DataLayout& layout) {
case DataLayout::kNCHW: case DataLayout::kNCHW:
return MKLDNNMemoryFormat::nchw; return MKLDNNMemoryFormat::nchw;
default: default:
PADDLE_THROW("Fail to convert layout %s to MKLDNN format", PADDLE_THROW(platform::errors::InvalidArgument(
DataLayoutToString(layout)); "Fail to convert layout %s to MKLDNN format.",
DataLayoutToString(layout)));
} }
} }
...@@ -50,7 +51,8 @@ inline DataLayout ToPaddleLayout(const MKLDNNMemoryFormat& format) { ...@@ -50,7 +51,8 @@ inline DataLayout ToPaddleLayout(const MKLDNNMemoryFormat& format) {
case MKLDNNMemoryFormat::nchw: case MKLDNNMemoryFormat::nchw:
return DataLayout::kNCHW; return DataLayout::kNCHW;
default: default:
PADDLE_THROW("Fail to convert MKLDNN format to paddle layout"); PADDLE_THROW(platform::errors::InvalidArgument(
"Fail to convert MKLDNN format to paddle layout."));
} }
} }
......
...@@ -45,9 +45,10 @@ void TransformData(const OpKernelType &expected_kernel_type, ...@@ -45,9 +45,10 @@ void TransformData(const OpKernelType &expected_kernel_type,
if (NeedTransformLayout(lout, lin)) { if (NeedTransformLayout(lout, lin)) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (lin == DataLayout::kMKLDNN || lout == DataLayout::kMKLDNN) { if (lin == DataLayout::kMKLDNN || lout == DataLayout::kMKLDNN) {
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
!(lin == DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN), !(lin == DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN), true,
"No layout transform needed between two MKLDNN OPKernels"); platform::errors::PreconditionNotMet(
"No layout transform needed between two MKLDNN OPKernels."));
if (lin != DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN) { if (lin != DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN) {
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel // Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
...@@ -96,7 +97,10 @@ void TransformData(const OpKernelType &expected_kernel_type, ...@@ -96,7 +97,10 @@ void TransformData(const OpKernelType &expected_kernel_type,
PassTensorData(&out, &in); PassTensorData(&out, &in);
} }
PADDLE_ENFORCE(transformed, "No transform is applied, please check!"); PADDLE_ENFORCE_EQ(
transformed, true,
platform::errors::PreconditionNotMet(
"No transform is applied for the data needs to be transformed."));
// get output data // get output data
output_tensor->ShareDataWith(in); output_tensor->ShareDataWith(in);
} }
...@@ -116,7 +120,10 @@ void SetTensorToVariable(const Variable &in_var, const Tensor &tensor, ...@@ -116,7 +120,10 @@ void SetTensorToVariable(const Variable &in_var, const Tensor &tensor,
trans_selected_rows->set_rows(in_selected_rows.rows()); trans_selected_rows->set_rows(in_selected_rows.rows());
trans_selected_rows->mutable_value()->ShareDataWith(tensor); trans_selected_rows->mutable_value()->ShareDataWith(tensor);
} else { } else {
PADDLE_THROW("unknown var type"); PADDLE_THROW(platform::errors::Unavailable(
"Unsupported variable type, only supports LoDTensor or SelectedRows, "
"but the input variable type is %s.",
ToTypeName(in_var.Type())));
} }
} }
......
...@@ -65,7 +65,8 @@ proto::VarType::Type ToDataType(std::type_index type) { ...@@ -65,7 +65,8 @@ proto::VarType::Type ToDataType(std::type_index type) {
if (it != gDataTypeMap().cpp_to_proto_.end()) { if (it != gDataTypeMap().cpp_to_proto_.end()) {
return it->second; return it->second;
} }
PADDLE_THROW("Not support %s as tensor type", type.name()); PADDLE_THROW(platform::errors::Unimplemented(
"Not support %s as tensor data type.", platform::demangle(type.name())));
} }
std::type_index ToTypeIndex(proto::VarType::Type type) { std::type_index ToTypeIndex(proto::VarType::Type type) {
...@@ -73,8 +74,9 @@ std::type_index ToTypeIndex(proto::VarType::Type type) { ...@@ -73,8 +74,9 @@ std::type_index ToTypeIndex(proto::VarType::Type type) {
if (it != gDataTypeMap().proto_to_cpp_.end()) { if (it != gDataTypeMap().proto_to_cpp_.end()) {
return it->second; return it->second;
} }
PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type", PADDLE_THROW(platform::errors::Unimplemented(
static_cast<int>(type)); "Not support proto::VarType::Type(%d) as tensor type.",
static_cast<int>(type)));
} }
std::string DataTypeToString(const proto::VarType::Type type) { std::string DataTypeToString(const proto::VarType::Type type) {
...@@ -82,8 +84,9 @@ std::string DataTypeToString(const proto::VarType::Type type) { ...@@ -82,8 +84,9 @@ std::string DataTypeToString(const proto::VarType::Type type) {
if (it != gDataTypeMap().proto_to_str_.end()) { if (it != gDataTypeMap().proto_to_str_.end()) {
return it->second; return it->second;
} }
PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type", PADDLE_THROW(platform::errors::Unimplemented(
static_cast<int>(type)); "Not support proto::VarType::Type(%d) as tensor type.",
static_cast<int>(type)));
} }
size_t SizeOfType(proto::VarType::Type type) { size_t SizeOfType(proto::VarType::Type type) {
...@@ -91,7 +94,8 @@ size_t SizeOfType(proto::VarType::Type type) { ...@@ -91,7 +94,8 @@ size_t SizeOfType(proto::VarType::Type type) {
if (it != gDataTypeMap().proto_to_size_.end()) { if (it != gDataTypeMap().proto_to_size_.end()) {
return it->second; return it->second;
} }
PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type)); PADDLE_THROW(platform::errors::Unimplemented("Not support %s as tensor type.",
DataTypeToString(type)));
} }
} // namespace framework } // namespace framework
......
...@@ -78,7 +78,9 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { ...@@ -78,7 +78,9 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
_ForEachDataType_(VisitDataTypeCallback); _ForEachDataType_(VisitDataTypeCallback);
#undef VisitDataTypeCallback #undef VisitDataTypeCallback
PADDLE_THROW("Not supported %d", type); PADDLE_THROW(platform::errors::Unimplemented(
"Not supported proto::VarType::Type(%d) as data type.",
static_cast<int>(type)));
} }
template <typename Visitor> template <typename Visitor>
......
...@@ -56,7 +56,8 @@ struct CastDataType { ...@@ -56,7 +56,8 @@ struct CastDataType {
context->Wait(); context->Wait();
#endif #endif
} else { } else {
PADDLE_THROW("Unsupported place!"); PADDLE_THROW(platform::errors::Unimplemented(
"Place type is not supported when casting data type."));
} }
} }
}; };
...@@ -98,7 +99,9 @@ void TransDataType(const OpKernelType& kernel_type_for_var, ...@@ -98,7 +99,9 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx)); framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
break; break;
default: default:
PADDLE_THROW("Not support type %d", src_type); PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when casting data type.",
DataTypeToString(src_type)));
} }
} }
......
...@@ -81,9 +81,11 @@ bool contain_unknown_dim(const DDim& ddim) { ...@@ -81,9 +81,11 @@ bool contain_unknown_dim(const DDim& ddim) {
} }
DDim slice_ddim(const DDim& dim, int begin, int end) { DDim slice_ddim(const DDim& dim, int begin, int end) {
PADDLE_ENFORCE(begin >= 0 && end <= dim.size(), PADDLE_ENFORCE_EQ(
"[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.", (begin >= 0 && end <= dim.size()), true,
begin, end, dim.size()); platform::errors::InvalidArgument(
"[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.", begin,
end, dim.size()));
// Constructor of DDim would check whether end - begin is valid // Constructor of DDim would check whether end - begin is valid
return DDim(dim.Get() + begin, end - begin); return DDim(dim.Get() + begin, end - begin);
} }
......
...@@ -29,20 +29,23 @@ namespace framework { ...@@ -29,20 +29,23 @@ namespace framework {
return (callback); \ return (callback); \
} }
#define PADDLE_VISIT_DDIM(rank, callback) \ #define PADDLE_VISIT_DDIM(rank, callback) \
switch (rank) { \ switch (rank) { \
PADDLE_VISIT_DDIM_BASE(0, callback); \ PADDLE_VISIT_DDIM_BASE(0, callback); \
PADDLE_VISIT_DDIM_BASE(1, callback); \ PADDLE_VISIT_DDIM_BASE(1, callback); \
PADDLE_VISIT_DDIM_BASE(2, callback); \ PADDLE_VISIT_DDIM_BASE(2, callback); \
PADDLE_VISIT_DDIM_BASE(3, callback); \ PADDLE_VISIT_DDIM_BASE(3, callback); \
PADDLE_VISIT_DDIM_BASE(4, callback); \ PADDLE_VISIT_DDIM_BASE(4, callback); \
PADDLE_VISIT_DDIM_BASE(5, callback); \ PADDLE_VISIT_DDIM_BASE(5, callback); \
PADDLE_VISIT_DDIM_BASE(6, callback); \ PADDLE_VISIT_DDIM_BASE(6, callback); \
PADDLE_VISIT_DDIM_BASE(7, callback); \ PADDLE_VISIT_DDIM_BASE(7, callback); \
PADDLE_VISIT_DDIM_BASE(8, callback); \ PADDLE_VISIT_DDIM_BASE(8, callback); \
PADDLE_VISIT_DDIM_BASE(9, callback); \ PADDLE_VISIT_DDIM_BASE(9, callback); \
default: \ default: \
PADDLE_THROW("Invalid rank %d", rank); \ PADDLE_THROW(platform::errors::Unimplemented( \
"Invalid dimension to be accessed. Now only supports access to " \
"dimension 0 to 9, but received dimension is %d.", \
rank)); \
} }
template <typename T1, typename T2> template <typename T1, typename T2>
...@@ -92,13 +95,31 @@ class DDim { ...@@ -92,13 +95,31 @@ class DDim {
inline int64_t operator[](int idx) const { return dim_[idx]; } inline int64_t operator[](int idx) const { return dim_[idx]; }
inline int64_t& at(int idx) { int64_t& at(int idx) {
PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx); PADDLE_ENFORCE_GE(idx, 0,
platform::errors::InvalidArgument(
"Invalid DDim index to be accessed. The valid index "
"is between 0 and %d, but received index is %d.",
rank_, idx));
PADDLE_ENFORCE_LT(idx, rank_,
platform::errors::InvalidArgument(
"Invalid DDim index to be accessed. The valid index "
"is between 0 and %d, but received index is %d.",
rank_, idx));
return dim_[idx]; return dim_[idx];
} }
inline int64_t at(int idx) const { int64_t at(int idx) const {
PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx); PADDLE_ENFORCE_GE(idx, 0,
platform::errors::InvalidArgument(
"Invalid DDim index to be accessed. The valid index "
"is between 0 and %d, but received index is %d.",
rank_, idx));
PADDLE_ENFORCE_LT(idx, rank_,
platform::errors::InvalidArgument(
"Invalid DDim index to be accessed. The valid index "
"is between 0 and %d, but received index is %d.",
rank_, idx));
return dim_[idx]; return dim_[idx];
} }
......
...@@ -42,53 +42,18 @@ inline void InitVarsInScope(const std::vector<VarInfo> &var_infos, Scope *scope, ...@@ -42,53 +42,18 @@ inline void InitVarsInScope(const std::vector<VarInfo> &var_infos, Scope *scope,
} }
} }
// get RpcContext and remote send and recv op // get CommContext and remote send and recv op
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
using RpcCtxMap = operators::distributed::RpcCtxMap;
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx;
for (auto &node : graphs[0]->Nodes()) {
VLOG(3) << "node name " << node->Name();
if (node && node->IsOp()) {
if (node->Name() == "send") {
auto send_var_name = node->Op()->Input("X")[0];
auto send_varnames =
BOOST_GET_CONST(std::vector<std::string>,
node->Op()->GetNullableAttr("send_varnames"));
auto epmap = BOOST_GET_CONST(std::vector<std::string>,
node->Op()->GetNullableAttr("epmap"));
auto height_section = BOOST_GET_CONST(
std::vector<int64_t>, node->Op()->GetNullableAttr("sections"));
auto trainer_id =
BOOST_GET_CONST(int, node->Op()->GetNullableAttr("trainer_id"));
auto merge_add =
BOOST_GET_CONST(bool, node->Op()->GetNullableAttr("merge_add"));
if (!merge_add) {
merge_add = FLAGS_communicator_is_sgd_optimizer;
}
auto use_send_handler = BOOST_GET_CONST(
bool, node->Op()->GetNullableAttr("use_send_handler"));
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
send_var_name, send_varnames, epmap, height_section, trainer_id,
merge_add, use_send_handler);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
}
}
}
// init communicator here // init communicator here
if (send_varname_to_ctx.size() > 0) { auto *instance = operators::distributed::Communicator::GetInstance();
auto *instance = operators::distributed::Communicator::GetInstance(); auto initialized = instance ? true : false;
auto initialized = instance ? true : false; PADDLE_ENFORCE_EQ(initialized, true,
PADDLE_ENFORCE_EQ(initialized, true, platform::errors::InvalidArgument(
platform::errors::InvalidArgument( "Communicator is not Initialized, you may use "
"Communicator is not Initialized, you may use " "FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
"FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/" "develop/markdown_doc/transpiler)"));
"develop/markdown_doc/transpiler)"));
}
#endif #endif
} }
......
...@@ -107,21 +107,31 @@ class ExceptionHolder { ...@@ -107,21 +107,31 @@ class ExceptionHolder {
type_ = kNone; type_ = kNone;
} }
// NOTE: currently in PE, multiple exceptions may occur in multiple threads,
// and an exception that occurs later would overwrite one that occurred
// earlier, but what we want to report is the first triggered exception.
// However, an EOF exception has lower priority and may be overwritten;
// any other exception that is already held must not be replaced.
void Catch(const platform::EnforceNotMet& exp) { void Catch(const platform::EnforceNotMet& exp) {
std::lock_guard<std::mutex> lock(mu_); std::lock_guard<std::mutex> lock(mu_);
exception_.reset(new platform::EnforceNotMet(exp)); if (exception_.get() == nullptr || type_ == kEOF) {
type_ = kEnforceNotMet; exception_.reset(new platform::EnforceNotMet(exp));
type_ = kEnforceNotMet;
} else {
VLOG(2) << "Non-first exception is discarded, the error message is"
<< exception_->what();
}
} }
void Catch(const memory::allocation::BadAlloc& exp) { void Catch(const memory::allocation::BadAlloc& exp) {
std::lock_guard<std::mutex> lock(mu_); std::lock_guard<std::mutex> lock(mu_);
// BadAlloc have the highest priority if (exception_.get() == nullptr || type_ == kEOF) {
if (exception_.get() != nullptr) { exception_.reset(new paddle::memory::allocation::BadAlloc(exp));
VLOG(2) << "exception is reset by BadAlloc, the original error message is" type_ = kBadAlloc;
} else {
VLOG(2) << "Non-first exception is discarded, the error message is"
<< exception_->what(); << exception_->what();
} }
exception_.reset(new paddle::memory::allocation::BadAlloc(exp));
type_ = kBadAlloc;
} }
void Catch(const platform::EOFException& exp) { void Catch(const platform::EOFException& exp) {
...@@ -138,10 +148,12 @@ class ExceptionHolder { ...@@ -138,10 +148,12 @@ class ExceptionHolder {
void Catch(const std::exception& exp) { void Catch(const std::exception& exp) {
std::lock_guard<std::mutex> lock(mu_); std::lock_guard<std::mutex> lock(mu_);
// std::exception will not cover anything if (exception_.get() == nullptr || type_ == kEOF) {
if (exception_.get() == nullptr) {
exception_.reset(new std::exception(exp)); exception_.reset(new std::exception(exp));
type_ = kBaseException; type_ = kBaseException;
} else {
VLOG(2) << "Non-first exception is discarded, the error message is"
<< exception_->what();
} }
} }
......
...@@ -24,6 +24,29 @@ namespace details { ...@@ -24,6 +24,29 @@ namespace details {
namespace f = paddle::framework; namespace f = paddle::framework;
namespace p = paddle::platform; namespace p = paddle::platform;
// Checks that an EnforceNotMet captured via std::exception_ptr is reported
// under the "EnforceNotMet" type name and is rethrown as
// platform::EnforceNotMet.
TEST(ExceptionHolderTester, TestEnforceNotMetCatch) {
  ExceptionHolder holder;

  // Capture the in-flight exception through the exception_ptr entry point.
  try {
    throw platform::EnforceNotMet("enforce not met test", "test_file", 0);
  } catch (...) {
    holder.Catch(std::current_exception());
  }
  ASSERT_TRUE(holder.IsCaught());
  ASSERT_EQ(holder.Type(), "EnforceNotMet");

  // The flag only flips when the rethrown object has the expected type.
  bool rethrown_as_enforce_not_met = false;
  try {
    holder.ReThrow();
  } catch (const platform::EnforceNotMet&) {
    rethrown_as_enforce_not_met = true;
  } catch (...) {
    // Any other exception type leaves the flag false.
  }
  ASSERT_TRUE(rethrown_as_enforce_not_met);
}
TEST(ExceptionHolderTester, TestBadAllocCatch) { TEST(ExceptionHolderTester, TestBadAllocCatch) {
ExceptionHolder exception_holder; ExceptionHolder exception_holder;
...@@ -70,15 +93,24 @@ TEST(ExceptionHolderTester, TestBaseExpceptionCatch) { ...@@ -70,15 +93,24 @@ TEST(ExceptionHolderTester, TestBaseExpceptionCatch) {
ASSERT_TRUE(catch_base_exception); ASSERT_TRUE(catch_base_exception);
} }
TEST(ExceptionHolderTester, TestBadAllocCatchReplace) { TEST(ExceptionHolderTester, TestExceptionReplace) {
ExceptionHolder exception_holder; ExceptionHolder exception_holder;
try {
throw platform::EnforceNotMet("enforce not met test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "EnforceNotMet");
try { try {
throw std::exception(); throw std::exception();
} catch (...) { } catch (...) {
exception_holder.Catch(std::current_exception()); exception_holder.Catch(std::current_exception());
} }
ASSERT_TRUE(exception_holder.IsCaught()); ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BaseException"); ASSERT_EQ(exception_holder.Type(), "EnforceNotMet");
try { try {
throw memory::allocation::BadAlloc("bad alloc test", "test_file", 0); throw memory::allocation::BadAlloc("bad alloc test", "test_file", 0);
...@@ -86,13 +118,31 @@ TEST(ExceptionHolderTester, TestBadAllocCatchReplace) { ...@@ -86,13 +118,31 @@ TEST(ExceptionHolderTester, TestBadAllocCatchReplace) {
exception_holder.Catch(std::current_exception()); exception_holder.Catch(std::current_exception());
} }
ASSERT_TRUE(exception_holder.IsCaught()); ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BadAlloc"); ASSERT_EQ(exception_holder.Type(), "EnforceNotMet");
try { try {
throw platform::EOFException("eof test", "test_file", 0); throw platform::EOFException("eof test", "test_file", 0);
} catch (...) { } catch (...) {
exception_holder.Catch(std::current_exception()); exception_holder.Catch(std::current_exception());
} }
ASSERT_EQ(exception_holder.Type(), "EnforceNotMet");
exception_holder.Clear();
try {
throw memory::allocation::BadAlloc("bad alloc test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BadAlloc");
try {
throw platform::EnforceNotMet("enforce not met test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BadAlloc"); ASSERT_EQ(exception_holder.Type(), "BadAlloc");
} }
......
...@@ -269,7 +269,14 @@ void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) { ...@@ -269,7 +269,14 @@ void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
void FastThreadedSSAGraphExecutor::ExecutionFinal( void FastThreadedSSAGraphExecutor::ExecutionFinal(
std::vector<OpHandleBase *> *fetch_ops) { std::vector<OpHandleBase *> *fetch_ops) {
VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it"; VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
ClearFetchOp(graph_, fetch_ops); // NOTE: If a new exception occurs in this ClearFetchOp operation, it will
// cause the originally triggered exception to be lost instead of rethrown.
// Therefore, the cleanup operation should only be performed when an EOF
// exception is caught. If any other exception was triggered, ClearFetchOp
// must be skipped.
if (exception_.Type() == "EOF") {
ClearFetchOp(graph_, fetch_ops);
}
exception_.ReThrow(); exception_.ReThrow();
} }
......
...@@ -36,7 +36,7 @@ OpHandleBase::~OpHandleBase() PADDLE_MAY_THROW { ...@@ -36,7 +36,7 @@ OpHandleBase::~OpHandleBase() PADDLE_MAY_THROW {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
for (auto &ev : events_) { for (auto &ev : events_) {
if (ev.second) { if (ev.second) {
PADDLE_ENFORCE(cudaEventDestroy(ev.second)); PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(ev.second));
} }
} }
#endif #endif
......
...@@ -111,6 +111,7 @@ void DeviceWorker::DumpParam(const Scope& scope, const int batch_id) { ...@@ -111,6 +111,7 @@ void DeviceWorker::DumpParam(const Scope& scope, const int batch_id) {
writer_ << os.str(); writer_ << os.str();
} }
} }
void DeviceWorker::InitRandomDumpConfig(const TrainerDesc& desc) { void DeviceWorker::InitRandomDumpConfig(const TrainerDesc& desc) {
bool enable_random_dump = desc.enable_random_dump(); bool enable_random_dump = desc.enable_random_dump();
if (!enable_random_dump) { if (!enable_random_dump) {
......
...@@ -335,6 +335,7 @@ class SectionWorker : public DeviceWorker { ...@@ -335,6 +335,7 @@ class SectionWorker : public DeviceWorker {
void SetSkipVars(const std::vector<std::string>& skip_vars) { void SetSkipVars(const std::vector<std::string>& skip_vars) {
skip_vars_ = skip_vars; skip_vars_ = skip_vars;
} }
// Resets the class-wide batch counter (batch_id_) back to zero.
static void ResetBatchId() {
  batch_id_ = 0;
}
static std::atomic<int> cpu_id_; static std::atomic<int> cpu_id_;
......
...@@ -99,7 +99,7 @@ void DistMultiTrainer::InitTrainerEnv(const ProgramDesc &main_program, ...@@ -99,7 +99,7 @@ void DistMultiTrainer::InitTrainerEnv(const ProgramDesc &main_program,
} }
void DistMultiTrainer::InitOtherEnv(const ProgramDesc &main_program) { void DistMultiTrainer::InitOtherEnv(const ProgramDesc &main_program) {
if (need_dump_field_) { if (need_dump_field_ || need_dump_param_) {
InitDumpEnv(); InitDumpEnv();
} }
pull_dense_worker_->SetRootScope(root_scope_); pull_dense_worker_->SetRootScope(root_scope_);
...@@ -158,7 +158,7 @@ void DistMultiTrainer::Finalize() { ...@@ -158,7 +158,7 @@ void DistMultiTrainer::Finalize() {
} }
} }
if (need_dump_field_) { if (need_dump_field_ || need_dump_param_) {
FinalizeDumpEnv(); FinalizeDumpEnv();
} }
pull_dense_worker_->Stop(); pull_dense_worker_->Stop();
......
...@@ -49,7 +49,12 @@ TEST(DisMultiTrainerTest, test1) { ...@@ -49,7 +49,12 @@ TEST(DisMultiTrainerTest, test1) {
dataset->SetTrainerNum(1); dataset->SetTrainerNum(1);
dataset->SetDataFeedDesc(str); dataset->SetDataFeedDesc(str);
dataset->CreateReaders(); dataset->CreateReaders();
Scope root_scope;
tmp1->SetScope(&root_scope);
tmp1->Initialize(t, dataset.get()); tmp1->Initialize(t, dataset.get());
ProgramDesc p;
tmp1->InitOtherEnv(p);
tmp1->Finalize();
#endif #endif
} }
} // namespace framework } // namespace framework
......
...@@ -22,56 +22,104 @@ enum Mode { ...@@ -22,56 +22,104 @@ enum Mode {
HETER = 4; // support XPU and GPU computing server HETER = 4; // support XPU and GPU computing server
} }
message DistributedStrategy { message RecomputeConfig { repeated string checkpoints = 1; }
optional Mode mode = 1 [ default = COLLECTIVE ]; // just for serialization
// collective training strategy message AMPConfig {
optional bool amp = 2 [ default = false ]; optional float init_loss_scaling = 1 [ default = 32768.0 ];
optional int32 amp_loss_scaling = 3 [ default = 32768 ]; optional int32 incr_every_n_steps = 2 [ default = 1000 ];
optional bool recompute = 4 [ default = false ]; optional int32 decr_every_n_nan_or_inf = 3 [ default = 2 ];
repeated string recompute_checkpoints = 5; optional float incr_ratio = 4 [ default = 2.0 ];
optional bool localsgd = 6 [ default = false ]; optional float decr_ratio = 5 [ default = 0.8 ];
optional int32 localsgd_k_step = 7 [ default = 4 ]; optional bool use_dynamic_loss_scaling = 6 [ default = true ];
optional bool dgc = 8 [ default = false ]; repeated string custom_white_list = 7;
optional bool hierachical_allreduce = 9 [ default = false ]; repeated string custom_black_list = 8;
optional int32 nccl_comm_num = 10 [ default = 1 ]; repeated string custom_black_varnames = 9;
optional bool gradient_merge = 11 [ default = false ]; }
optional int32 gradient_merge_k_step = 12 [ default = 1 ];
optional bool sequential_execution = 13 [ default = false ]; message LocalSGDConfig { optional int32 k_steps = 1 [ default = 4 ]; }
optional bool enable_backward_optimizer_op_deps = 14 [ default = true ];
optional bool lars = 15 [ default = false ]; message GradientMergeConfig {
optional bool lamb = 16 [ default = false ]; optional int32 k_steps = 1 [ default = 1 ];
optional bool fuse_elewise_add_act_ops = 17 [ default = false ]; optional bool avg = 2 [ default = true ];
optional bool fuse_bn_act_ops = 18 [ default = false ]; }
optional bool enable_auto_fusion = 19 [ default = false ];
optional bool fuse_relu_depthwise_conv = 20 [ default = false ]; message LarsConfig {
optional bool enable_inplace = 21 [ default = false ]; optional float lars_coeff = 1 [ default = 0.001 ];
optional bool fuse_all_reduce_ops = 22 [ default = false ]; optional float lars_weight_decay = 2 [ default = 0.0005 ];
optional int32 num_iteration_per_drop_scope = 23 [ default = 1 ]; }
optional bool sync_batch_norm = 24 [ default = false ];
optional bool fuse_all_optimizer_ops = 25 [ default = false ];
// pipeline training message LambConfig {
optional bool pipeline = 101 [ default = false ]; optional float beta1 = 1 [ default = 0.001 ];
optional int32 pipeline_micro_batch = 102; optional float beta2 = 2 [ default = 0.999 ];
optional float epsilon = 3 [ default = 0.000001 ];
}
// parameter server training message BuildStrategy {
optional bool sync = 201 [ default = false ]; optional bool enable_sequential_execution = 1 [ default = false ];
optional bool async = 202 [ default = true ]; optional bool fuse_elewise_add_act_ops = 2 [ default = false ];
optional int32 async_k_step = 203 [ default = -1 ]; optional bool fuse_bn_act_ops = 3 [ default = false ];
optional int32 max_merge_var_num = 204 [ default = 1 ]; optional bool fuse_relu_depthwise_conv = 4 [ default = false ];
optional int32 send_queue_size = 205 [ default = 16 ]; optional bool fuse_broadcast_ops = 5 [ default = false ];
optional bool independent_recv_thread = 206 [ default = false ]; optional bool fuse_all_optimizer_ops = 6 [ default = false ];
optional int32 min_send_grad_num_before_recv = 207 [ default = 1 ]; optional bool enable_inplace = 7 [ default = false ];
optional int32 thread_pool_size = 208 [ default = 1 ]; optional bool enable_backward_optimizer_op_deps = 8 [ default = true ];
optional int32 send_wait_times = 209 [ default = 1 ]; optional bool cache_runtime_context = 9 [ default = false ];
optional bool runtime_split_send_recv = 210 [ default = false ]; }
optional bool use_thread_barrier = 211 [ default = false ];
// elastic deep learning strategies message ExecutionStrategy {
optional bool elastic = 301 [ default = false ]; optional int32 num_threads = 1 [ default = 1 ];
optional int32 num_iteration_per_drop_scope = 2 [ default = 10 ];
optional int32 num_iteration_per_run = 3 [ default = 1 ];
optional bool use_thread_barrier = 4 [ default = false ];
}
// Tuning options grouped out of DistributedStrategy; referenced there as
// `a_sync_configs`. These fields previously lived inline under the
// "parameter server training" section of DistributedStrategy.
// NOTE(review): presumably only consulted when DistributedStrategy.a_sync is
// true - confirm against the consumer of this message.
message AsyncConfig {
optional int32 k_steps = 1 [ default = 1 ];
optional int32 max_merge_var_num = 2 [ default = 1 ];
optional int32 send_queue_size = 3 [ default = 16 ];
optional bool independent_recv_thread = 4 [ default = false ];
optional int32 min_send_grad_num_before_recv = 5 [ default = 1 ];
optional int32 thread_pool_size = 6 [ default = 1 ];
optional int32 send_wait_times = 7 [ default = 1 ];
optional bool runtime_split_send_recv = 8 [ default = false ];
}
// Options for pipeline training; referenced by DistributedStrategy as
// `pipeline_configs`. `micro_batch` defaults to 1.
message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
message DistributedStrategy {
// bool options
optional Mode mode = 1 [ default = COLLECTIVE ];
optional bool amp = 2 [ default = false ];
optional bool recompute = 3 [ default = false ];
optional bool localsgd = 4 [ default = false ];
optional bool dgc = 5 [ default = false ];
optional bool gradient_merge = 6 [ default = false ];
optional bool lars = 7 [ default = false ];
optional bool lamb = 8 [ default = false ];
optional bool pipeline = 9 [ default = false ];
optional bool elastic = 10 [ default = false ];
optional bool auto = 11 [ default = false ];
optional bool a_sync = 12 [ default = true ];
optional bool sync_nccl_allreduce = 13 [ default = true ];
optional int32 nccl_comm_num = 14 [ default = 1 ];
optional bool use_hierarchical_allreduce = 15 [ default = false ];
optional int32 hierarchical_allreduce_inter_nranks = 16 [ default = 1 ];
optional bool sync_batch_norm = 17 [ default = false ];
optional bool fuse_all_reduce_ops = 18 [ default = true ];
optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ];
optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ];
// optional bool enable_backward_optimizer_op_deps = 19 [ default = true ];
// auto parallel optional RecomputeConfig recompute_configs = 101;
optional bool auto = 401 [ default = false ]; optional AMPConfig amp_configs = 102;
optional LocalSGDConfig localsgd_configs = 103;
optional GradientMergeConfig gradient_merge_configs = 104;
optional PipelineConfig pipeline_configs = 106;
optional AsyncConfig a_sync_configs = 107;
optional LarsConfig lars_configs = 108;
optional LambConfig lamb_configs = 109;
optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202;
} }
message DistributedJobInfo { message DistributedJobInfo {
......
...@@ -30,7 +30,10 @@ static ::DLDataType GetDLDataTypeCode() { ...@@ -30,7 +30,10 @@ static ::DLDataType GetDLDataTypeCode() {
} else if (std::is_integral<T>::value) { } else if (std::is_integral<T>::value) {
dtype.code = kDLInt; dtype.code = kDLInt;
} else { } else {
PADDLE_THROW("Unsupported data type %s", typeid(T).name()); PADDLE_THROW(platform::errors::Unavailable(
"Unsupported data type (%s), only supports float16, float, unsigned "
"int and int.",
platform::demangle(typeid(T).name())));
} }
dtype.bits = 8 * sizeof(T); dtype.bits = 8 * sizeof(T);
dtype.lanes = 1; dtype.lanes = 1;
...@@ -52,8 +55,9 @@ static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) { ...@@ -52,8 +55,9 @@ static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
static auto type_to_dtype_map = CreateDLDataTypeMap(); static auto type_to_dtype_map = CreateDLDataTypeMap();
static auto type_to_dtype_map_end_it = type_to_dtype_map.end(); static auto type_to_dtype_map_end_it = type_to_dtype_map.end();
auto it = type_to_dtype_map.find(static_cast<int>(type)); auto it = type_to_dtype_map.find(static_cast<int>(type));
PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %d", PADDLE_ENFORCE_NE(it, type_to_dtype_map_end_it,
type); platform::errors::InvalidArgument(
"Unsupported data type (%s).", DataTypeToString(type)));
return it->second; return it->second;
#undef REG_DL_DATA_TYPE #undef REG_DL_DATA_TYPE
} }
...@@ -73,7 +77,8 @@ struct DLContextVisitor : public boost::static_visitor<::DLContext> { ...@@ -73,7 +77,8 @@ struct DLContextVisitor : public boost::static_visitor<::DLContext> {
ctx.device_id = place.device; ctx.device_id = place.device;
return ctx; return ctx;
#else #else
PADDLE_THROW("platform::CUDAPlace is not supported in CPU only version"); PADDLE_THROW(platform::errors::Unavailable(
"platform::CUDAPlace is not supported in CPU only version."));
#endif #endif
} }
...@@ -84,8 +89,8 @@ struct DLContextVisitor : public boost::static_visitor<::DLContext> { ...@@ -84,8 +89,8 @@ struct DLContextVisitor : public boost::static_visitor<::DLContext> {
ctx.device_id = 0; ctx.device_id = 0;
return ctx; return ctx;
#else #else
PADDLE_THROW( PADDLE_THROW(platform::errors::Unavailable(
"platform::CUDAPinnedPlace is not supported in CPU only version"); "platform::CUDAPinnedPlace is not supported in CPU only version."));
#endif #endif
} }
}; };
...@@ -136,7 +141,10 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) { ...@@ -136,7 +141,10 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
// refer to cupy and cudf, the compact tensor first dim's strides need to be 1 // refer to cupy and cudf, the compact tensor first dim's strides need to be 1
// and second dim's strides need to be length of rows of cudf // and second dim's strides need to be length of rows of cudf
// cudf now only support dim=2 // cudf now only support dim=2
PADDLE_ENFORCE_LE(t_.ndim, 2, "cudf now only support dim=2."); PADDLE_ENFORCE_LE(t_.ndim, 2, platform::errors::InvalidArgument(
"cudf now only supports dimension is 2, "
"but received dimension is %d.",
t_.ndim));
if (t_.ndim > 1) if (t_.ndim > 1)
t_.strides = new int64_t[2]{1, t_.shape[1]}; t_.strides = new int64_t[2]{1, t_.shape[1]};
......
...@@ -556,9 +556,11 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -556,9 +556,11 @@ void DownpourWorker::TrainFilesWithProfiler() {
continue; continue;
} }
PADDLE_ENFORCE_EQ(framework::TensorContainsInf(*tensor), false, PADDLE_ENFORCE_EQ(framework::TensorContainsInf(*tensor), false,
"Tensor %s contains Inf", var_name); platform::errors::InvalidArgument(
"Tensor %s contains Inf.", var_name));
PADDLE_ENFORCE_EQ(framework::TensorContainsNAN(*tensor), false, PADDLE_ENFORCE_EQ(framework::TensorContainsNAN(*tensor), false,
"Tensor %s contains NAN", var_name); platform::errors::InvalidArgument(
"Tensor %s contains NAN.", var_name));
} }
if (need_to_push_sparse_) { if (need_to_push_sparse_) {
...@@ -829,9 +831,11 @@ void DownpourWorker::TrainFiles() { ...@@ -829,9 +831,11 @@ void DownpourWorker::TrainFiles() {
continue; continue;
} }
PADDLE_ENFORCE_EQ(framework::TensorContainsInf(*tensor), false, PADDLE_ENFORCE_EQ(framework::TensorContainsInf(*tensor), false,
"Tensor %s contains Inf", var_name); platform::errors::InvalidArgument(
"Tensor %s contains Inf.", var_name));
PADDLE_ENFORCE_EQ(framework::TensorContainsNAN(*tensor), false, PADDLE_ENFORCE_EQ(framework::TensorContainsNAN(*tensor), false,
"Tensor %s contains NAN", var_name); platform::errors::InvalidArgument(
"Tensor %s contains NAN.", var_name));
} }
if (need_to_push_sparse_) { if (need_to_push_sparse_) {
......
...@@ -26,7 +26,11 @@ struct EigenDim { ...@@ -26,7 +26,11 @@ struct EigenDim {
using Type = Eigen::DSizes<Eigen::DenseIndex, D>; using Type = Eigen::DSizes<Eigen::DenseIndex, D>;
static Type From(const DDim& dims) { static Type From(const DDim& dims) {
PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); PADDLE_ENFORCE_EQ(arity(dims), D,
platform::errors::InvalidArgument(
"Input dimension size should be equal to %d, but "
"received dimension size is %d.",
arity(dims), D));
Type ret; Type ret;
for (int64_t d = 0; d < arity(dims); d++) { for (int64_t d = 0; d < arity(dims); d++) {
ret[d] = dims[d]; ret[d] = dims[d];
...@@ -69,8 +73,11 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> { ...@@ -69,8 +73,11 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
static typename EigenMatrix::Type Reshape(Tensor& tensor, // NOLINT static typename EigenMatrix::Type Reshape(Tensor& tensor, // NOLINT
int num_col_dims) { int num_col_dims) {
int rank = tensor.dims_.size(); int rank = tensor.dims_.size();
PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank, PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true,
"`num_col_dims` must be between (0, rank_of_tensor)."); platform::errors::InvalidArgument(
"Input dimension number(num_col_dims) must be "
"between 0 and %d, but received number is %d.",
rank, num_col_dims));
return EigenMatrix::From(tensor, return EigenMatrix::From(tensor,
flatten_to_2d(tensor.dims(), num_col_dims)); flatten_to_2d(tensor.dims(), num_col_dims));
} }
...@@ -78,8 +85,11 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> { ...@@ -78,8 +85,11 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
static typename EigenMatrix::ConstType Reshape(const Tensor& tensor, static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
int num_col_dims) { int num_col_dims) {
int rank = tensor.dims_.size(); int rank = tensor.dims_.size();
PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank, PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true,
"`num_col_dims` must be between (0, rank_of_tensor)."); platform::errors::InvalidArgument(
"Input dimension number(num_col_dims) must be "
"between 0 and %d, but received number is %d.",
rank, num_col_dims));
return EigenMatrix::From(tensor, return EigenMatrix::From(tensor,
flatten_to_2d(tensor.dims(), num_col_dims)); flatten_to_2d(tensor.dims(), num_col_dims));
} }
......
...@@ -37,9 +37,12 @@ limitations under the License. */ ...@@ -37,9 +37,12 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run"); DECLARE_bool(use_mkldnn);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -83,14 +86,7 @@ Executor::~Executor() { ...@@ -83,14 +86,7 @@ Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache, // Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working // this is needed to have mkl-dnn unit tests working
if (platform::is_cpu_place(place_)) { ClearMKLDNNCache(place_);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place_);
dev_ctx->ResetBlobMap();
platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
paddle::framework::DataLayout::kNCHW);
}
#endif #endif
} }
......
...@@ -175,8 +175,9 @@ void DeleteUnusedTensors( ...@@ -175,8 +175,9 @@ void DeleteUnusedTensors(
garbages.emplace_back(t.MoveMemoryHolder()); garbages.emplace_back(t.MoveMemoryHolder());
} }
} else { } else {
PADDLE_THROW("Type %s of %s is not supported eager deletion", PADDLE_THROW(platform::errors::Unimplemented(
framework::ToTypeName(var->Type()), var_name); "Type %s of variable %s is not supported eager deletion.",
framework::ToTypeName(var->Type()), var_name));
} }
} }
......
...@@ -23,9 +23,6 @@ ...@@ -23,9 +23,6 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <size_t EMBEDX_DIM, size_t EXPAND_EMBED_DIM> template <size_t EMBEDX_DIM, size_t EXPAND_EMBED_DIM>
__global__ void PullCopy( __global__ void PullCopy(
......
...@@ -79,15 +79,15 @@ StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place, ...@@ -79,15 +79,15 @@ StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size) size_t max_memory_size)
: GarbageCollector(place, max_memory_size) { : GarbageCollector(place, max_memory_size) {
platform::CUDADeviceGuard guard(place.device); platform::CUDADeviceGuard guard(place.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_)); PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream_));
callback_manager_.reset(new platform::StreamCallbackManager(stream_)); callback_manager_.reset(new platform::StreamCallbackManager(stream_));
} }
StreamGarbageCollector::~StreamGarbageCollector() { StreamGarbageCollector::~StreamGarbageCollector() {
auto place = BOOST_GET_CONST(platform::CUDAPlace, this->dev_ctx_->GetPlace()); auto place = BOOST_GET_CONST(platform::CUDAPlace, this->dev_ctx_->GetPlace());
platform::CUDADeviceGuard guard(place.device); platform::CUDADeviceGuard guard(place.device);
PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaStreamDestroy(stream_)); PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_));
} }
cudaStream_t StreamGarbageCollector::stream() const { return stream_; } cudaStream_t StreamGarbageCollector::stream() const { return stream_; }
......
...@@ -96,14 +96,14 @@ class GradOpDescMakerBase { ...@@ -96,14 +96,14 @@ class GradOpDescMakerBase {
if (!drop_empty_grad) { if (!drop_empty_grad) {
return ret_val; return ret_val;
} }
PADDLE_ENFORCE_LE(var_names.size(), 1UL, PADDLE_ENFORCE_LE(
"BUG from operator developer:" var_names.size(), 1UL,
" for input argument with a list of variables, " platform::errors::Unavailable(
" drop_empty_grad is not allowed because it makes" "BUG from operator developer:"
" the correspondence bewteen a variable and its gradient" " for input argument with a list of variables, "
" ambiguous." " drop_empty_grad is not allowed because it makes"
" Op type %s", " the correspondence bewteen a variable and its gradient"
fwd_op_.Type()); " ambiguous."));
std::vector<std::string> dropped_ret_val; std::vector<std::string> dropped_ret_val;
dropped_ret_val.reserve(ret_val.size()); dropped_ret_val.reserve(ret_val.size());
...@@ -157,7 +157,8 @@ class GradOpDescMakerBase { ...@@ -157,7 +157,8 @@ class GradOpDescMakerBase {
const Attribute& GetAttr(const std::string& name) const { const Attribute& GetAttr(const std::string& name) const {
auto& map = fwd_op_.GetAttrMap(); auto& map = fwd_op_.GetAttrMap();
auto it = map.find(name); auto it = map.find(name);
PADDLE_ENFORCE(it != map.end(), "Cannot find attribute %s", name); PADDLE_ENFORCE_NE(it, map.end(), platform::errors::NotFound(
"Cannot find attribute (%s).", name));
return it->second; return it->second;
} }
......
...@@ -53,7 +53,9 @@ void HogwildWorker::CreateThreadScope(const ProgramDesc &program) { ...@@ -53,7 +53,9 @@ void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
auto &block = program.Block(0); auto &block = program.Block(0);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
root_scope_, "root_scope should be set before creating thread scope"); root_scope_,
platform::errors::NotFound(
"Root scope should be set before creating thread scope."));
thread_scope_ = &root_scope_->NewScope(); thread_scope_ = &root_scope_->NewScope();
......
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include <fcntl.h> #include <fcntl.h>
#include <sys/stat.h> #include <sys/stat.h>
#ifdef _WIN32 #ifdef _WIN32
#ifndef NOMINMAX
#define NOMINMAX // msvc max/min macro conflict with std::min/max
#endif
#include <windows.h> #include <windows.h>
#else #else
#include <sys/syscall.h> #include <sys/syscall.h>
......
...@@ -4,7 +4,7 @@ file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeList ...@@ -4,7 +4,7 @@ file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeList
file(APPEND ${pass_file} "\#pragma once\n") file(APPEND ${pass_file} "\#pragma once\n")
file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n") file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
copy_if_different(${pass_file} ${pass_file_final} extern_glog) copy_if_different(${pass_file} ${pass_file_final})
add_subdirectory(fuse_optimizer_ops_pass) add_subdirectory(fuse_optimizer_ops_pass)
add_subdirectory(memory_optimize_pass) add_subdirectory(memory_optimize_pass)
......
...@@ -135,7 +135,9 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input, ...@@ -135,7 +135,9 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op) { void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op) {
// Check parameters // Check parameters
PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); PADDLE_ENFORCE_EQ(graph->Has(kParamScopeAttr), true,
platform::errors::InvalidArgument(
"Graph have no attribute: kParamScopeAttr."));
auto& scope = graph->Get<Scope>(kParamScopeAttr); auto& scope = graph->Get<Scope>(kParamScopeAttr);
// Create new parameters. // Create new parameters.
...@@ -193,7 +195,10 @@ void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op) { ...@@ -193,7 +195,10 @@ void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op) {
// reshape attention_bias // reshape attention_bias
auto* attention_bias_t = auto* attention_bias_t =
scope.FindVar(param.AttentionBias)->GetMutable<LoDTensor>(); scope.FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1); PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1,
platform::errors::InvalidArgument(
"Tensor attention bias dimension size(%d) must be 1.",
attention_bias_t->dims().size()));
attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]})); attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));
auto* attention_scalar_bias_t = auto* attention_scalar_bias_t =
...@@ -252,7 +257,10 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input, ...@@ -252,7 +257,10 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(), B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(),
B_cell.data<float>()}; B_cell.data<float>()};
PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1); PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1,
platform::errors::InvalidArgument(
"Tensor B forget dimension size(%d) must be 1.",
B_forget.dims().size()));
int D = B_forget.dims()[0]; int D = B_forget.dims()[0];
out->Resize(make_ddim({1, 4 * D})); out->Resize(make_ddim({1, 4 * D}));
auto* out_data = out->mutable_data<float>(platform::CPUPlace()); auto* out_data = out->mutable_data<float>(platform::CPUPlace());
......
...@@ -119,9 +119,11 @@ class CoalesceGradTensorPass : public ir::Pass { ...@@ -119,9 +119,11 @@ class CoalesceGradTensorPass : public ir::Pass {
p_g_dense_grad.insert(p_g_dense_grad.end(), group_p_g.begin(), p_g_dense_grad.insert(p_g_dense_grad.end(), group_p_g.begin(),
group_p_g.end()); group_p_g.end());
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(p_g_dense_grad.size(), num_of_p_g_dense_grad,
p_g_dense_grad.size(), num_of_p_g_dense_grad, platform::errors::InvalidArgument(
"The number of p_g_dense_grad is not consistent with before."); "The number of dense grads is not consistent with "
"previous. Previous(%d), now(%d).",
p_g_dense_grad.size(), num_of_p_g_dense_grad));
auto &pinned_var_set = auto &pinned_var_set =
graph->GetOrInit<details::PinnedVars>(details::kPinnedVars); graph->GetOrInit<details::PinnedVars>(details::kPinnedVars);
...@@ -131,8 +133,11 @@ class CoalesceGradTensorPass : public ir::Pass { ...@@ -131,8 +133,11 @@ class CoalesceGradTensorPass : public ir::Pass {
} else { } else {
for (auto &sub_param_grad : group_params_grads) { for (auto &sub_param_grad : group_params_grads) {
RecordGradients(p_g_dense_grad, vars_info, &pinned_var_set); RecordGradients(p_g_dense_grad, vars_info, &pinned_var_set);
PADDLE_ENFORCE_EQ(IsUnifiedDtype(sub_param_grad, vars_info), true, PADDLE_ENFORCE_EQ(
"The data type of the same group is not consistent."); IsUnifiedDtype(sub_param_grad, vars_info), true,
platform::errors::InvalidArgument("All gradient variable in "
"kGroupParamsAndDenseGrads, must "
"have same type."));
CoalesceTensors(vars_info, sub_param_grad, &result); CoalesceTensors(vars_info, sub_param_grad, &result);
} }
} }
...@@ -145,15 +150,25 @@ class CoalesceGradTensorPass : public ir::Pass { ...@@ -145,15 +150,25 @@ class CoalesceGradTensorPass : public ir::Pass {
// The Gradients should not be reused during memory optimization. // The Gradients should not be reused during memory optimization.
for (auto &p_g : sub_param_grad) { for (auto &p_g : sub_param_grad) {
auto iter = vars_info.find(p_g.second); auto iter = vars_info.find(p_g.second);
PADDLE_ENFORCE_EQ(iter != vars_info.end(), true, "%s is not found.", PADDLE_ENFORCE_EQ(iter != vars_info.end(), true,
p_g.second); platform::errors::NotFound(
PADDLE_ENFORCE_EQ(!iter->second.empty(), true); "Parameter@Grad %s is not found.", p_g.second));
PADDLE_ENFORCE_EQ(
!iter->second.empty(), true,
platform::errors::InvalidArgument(
"Parameter@Grad %s's var node is empty.", p_g.second));
for (auto it : iter->second) { for (auto it : iter->second) {
PADDLE_ENFORCE_NOT_NULL(it->Var()); PADDLE_ENFORCE_NOT_NULL(
it->Var(),
platform::errors::InvalidArgument(
"A node of Parameter@Grad %s does not hold variable.",
p_g.second));
pinned_var_set->insert(it->Var()->Name()); pinned_var_set->insert(it->Var()->Name());
} }
PADDLE_ENFORCE_EQ(IsLoDTensorType(GetTypeOfVar(vars_info, p_g.second)), PADDLE_ENFORCE_EQ(IsLoDTensorType(GetTypeOfVar(vars_info, p_g.second)),
true); true,
platform::errors::InvalidArgument(
"Parameter@Grad %s is not LoDTensor.", p_g.second));
} }
} }
...@@ -192,8 +207,10 @@ class CoalesceGradTensorPass : public ir::Pass { ...@@ -192,8 +207,10 @@ class CoalesceGradTensorPass : public ir::Pass {
auto fused_grad_var_name = std::string(details::kFusedVarNamePrefix) + auto fused_grad_var_name = std::string(details::kFusedVarNamePrefix) +
"@GRAD@" + params_grads.begin()->second; "@GRAD@" + params_grads.begin()->second;
auto &fused_var_set = result->Get<details::FusedVars>(details::kFusedVars); auto &fused_var_set = result->Get<details::FusedVars>(details::kFusedVars);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_grad_var_name), 0, PADDLE_ENFORCE_EQ(
"%s is duplicate in FusedVars.", fused_grad_var_name); fused_var_set.count(fused_grad_var_name), 0,
platform::errors::AlreadyExists("Var(%s) is duplicate in FusedVars.",
fused_grad_var_name));
fused_var_set.insert(fused_grad_var_name); fused_var_set.insert(fused_grad_var_name);
result->Get<details::FusedGrads>(details::kFusedGrads) result->Get<details::FusedGrads>(details::kFusedGrads)
.emplace_back(fused_grad_var_name); .emplace_back(fused_grad_var_name);
...@@ -420,11 +437,16 @@ class CoalesceGradTensorPass : public ir::Pass { ...@@ -420,11 +437,16 @@ class CoalesceGradTensorPass : public ir::Pass {
const std::unordered_map<std::string, std::vector<Node *>> &vars_info, const std::unordered_map<std::string, std::vector<Node *>> &vars_info,
const std::string &var_name) const { const std::string &var_name) const {
auto grad_iter = vars_info.find(var_name); auto grad_iter = vars_info.find(var_name);
PADDLE_ENFORCE_EQ(grad_iter != vars_info.end(), true, "%s is not found.", PADDLE_ENFORCE_EQ(
var_name); grad_iter != vars_info.end(), true,
PADDLE_ENFORCE_EQ(!grad_iter->second.empty(), true, "%s is not found.", platform::errors::NotFound("Variable %s is not found.", var_name));
var_name); PADDLE_ENFORCE_EQ(!grad_iter->second.empty(), true,
PADDLE_ENFORCE_NOT_NULL(grad_iter->second.front()->Var()); platform::errors::InvalidArgument(
"Variable %s's node is empty.", var_name));
PADDLE_ENFORCE_NOT_NULL(
grad_iter->second.front()->Var(),
platform::errors::InvalidArgument(
"A node of %s does not hold variable.", var_name));
return grad_iter->second.front()->Var(); return grad_iter->second.front()->Var();
} }
...@@ -464,7 +486,12 @@ class CoalesceGradTensorPass : public ir::Pass { ...@@ -464,7 +486,12 @@ class CoalesceGradTensorPass : public ir::Pass {
params_name.emplace_back(p_g.first); params_name.emplace_back(p_g.first);
grads_name.emplace_back(p_g.second); grads_name.emplace_back(p_g.second);
auto next_dtype = GetDtypeOfVar(vars_info, p_g.second); auto next_dtype = GetDtypeOfVar(vars_info, p_g.second);
PADDLE_ENFORCE_EQ(next_dtype, dtype); PADDLE_ENFORCE_EQ(
next_dtype, dtype,
platform::errors::InvalidArgument(
"All Parameter@Grad should have same dtype, but "
"there are two different type: %s, %s.",
DataTypeToString(next_dtype), DataTypeToString(dtype)));
} }
result->Get<details::ProgramDescs>(details::kProgramDescs).emplace_back(); result->Get<details::ProgramDescs>(details::kProgramDescs).emplace_back();
......
...@@ -50,7 +50,12 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight, ...@@ -50,7 +50,12 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>; Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
// Re-compute bias of conv2d from AffineChannel // Re-compute bias of conv2d from AffineChannel
PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), ac_bias_tensor.dims()); PADDLE_ENFORCE_EQ(
eltwise_y_in_tensor->dims(), ac_bias_tensor.dims(),
platform::errors::InvalidArgument(
"Tensor elementwise y(%d) and activation bias(%d) must have same "
"dimension.",
eltwise_y_in_tensor->dims().size(), ac_bias_tensor.dims().size()));
auto* scale_tensor = scope->FindVar(ac_scale.Name())->GetMutable<LoDTensor>(); auto* scale_tensor = scope->FindVar(ac_scale.Name())->GetMutable<LoDTensor>();
...@@ -78,11 +83,13 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight, ...@@ -78,11 +83,13 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
} }
void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope(); auto* scope = param_scope();
PADDLE_ENFORCE(scope); PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* conv_input = auto* conv_input =
...@@ -152,11 +159,13 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -152,11 +159,13 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
} }
void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope(); auto* scope = param_scope();
PADDLE_ENFORCE(scope); PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* conv_input = auto* conv_input =
......
...@@ -61,7 +61,12 @@ void recompute_bias_and_weights(const Scope* scope, ...@@ -61,7 +61,12 @@ void recompute_bias_and_weights(const Scope* scope,
Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>; Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
// Re-compute bias of conv2d from BN // Re-compute bias of conv2d from BN
PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), bn_bias_tensor.dims()); PADDLE_ENFORCE_EQ(
eltwise_y_in_tensor->dims(), bn_bias_tensor.dims(),
platform::errors::InvalidArgument("Tensor elementwise y(%d) and batch "
"norm bias(%d) must have same dims.",
eltwise_y_in_tensor->dims().size(),
bn_bias_tensor.dims().size()));
auto* scale_tensor = scope->FindVar(bn_scale.Name())->GetMutable<LoDTensor>(); auto* scale_tensor = scope->FindVar(bn_scale.Name())->GetMutable<LoDTensor>();
auto* variance_tensor = auto* variance_tensor =
...@@ -116,11 +121,13 @@ void recompute_bias_and_weights(const Scope* scope, ...@@ -116,11 +121,13 @@ void recompute_bias_and_weights(const Scope* scope,
} }
void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope(); auto* scope = param_scope();
PADDLE_ENFORCE(scope); PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* conv_input = auto* conv_input =
...@@ -186,11 +193,18 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -186,11 +193,18 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
if (has_bias && conv->Op()->Input("Bias").size() > 0) { if (has_bias && conv->Op()->Input("Bias").size() > 0) {
// reuse existing conv bias node // reuse existing conv bias node
auto conv_bias_names = conv->Op()->Input("Bias"); auto conv_bias_names = conv->Op()->Input("Bias");
PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1UL); PADDLE_ENFORCE_EQ(
conv_bias_names.size(), 1UL,
platform::errors::InvalidArgument("Find input var Bais error."));
auto* conv_bias_var = scope->FindVar(conv_bias_names[0]); auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>(); auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), PADDLE_ENFORCE_EQ(
eltwise_y_in_tensor->dims()); conv_bias_tensor->dims(), eltwise_y_in_tensor->dims(),
platform::errors::InvalidArgument(
"Tensor convolution bias(%d) and elementwise y(%d) "
"must have same dims.",
conv_bias_tensor->dims().size(),
eltwise_y_in_tensor->dims().size()));
auto eigen_conv_bias = EigenVector<float>::From(*conv_bias_tensor); auto eigen_conv_bias = EigenVector<float>::From(*conv_bias_tensor);
eigen_conv_bias += EigenVector<float>::From(*eltwise_y_in_tensor); eigen_conv_bias += EigenVector<float>::From(*eltwise_y_in_tensor);
...@@ -236,11 +250,13 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -236,11 +250,13 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
} }
void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const { void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope(); auto* scope = param_scope();
PADDLE_ENFORCE(scope); PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* conv_input = auto* conv_input =
......
...@@ -71,8 +71,16 @@ void TestMain(const std::string& conv_type) { ...@@ -71,8 +71,16 @@ void TestMain(const std::string& conv_type) {
int num_bn_nodes_after = GetNumOpNodes(graph, "batch_norm"); int num_bn_nodes_after = GetNumOpNodes(graph, "batch_norm");
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_bn_nodes_before, 1); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(num_bn_nodes_after, 0); num_bn_nodes_before, 1,
platform::errors::InvalidArgument(
"Before conv_bn_fuse_pass, number of batch norm op(%d) must be 1.",
num_bn_nodes_before));
PADDLE_ENFORCE_EQ(
num_bn_nodes_after, 0,
platform::errors::InvalidArgument(
"After conv_bn_fuse_pass, number of batch norm op(%d) must be 0.",
num_bn_nodes_after));
} }
TEST(ConvBNFusePass, conv2d) { TestMain("conv"); } TEST(ConvBNFusePass, conv2d) { TestMain("conv"); }
......
...@@ -91,7 +91,9 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -91,7 +91,9 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
auto* new_conv_op = graph->CreateOpNode(&new_op_desc); auto* new_conv_op = graph->CreateOpNode(&new_op_desc);
// Link inputs and outputs. // Link inputs and outputs.
PADDLE_ENFORCE(subgraph.count(x)); PADDLE_ENFORCE_NE(
subgraph.count(x), 0,
platform::errors::NotFound("Detector did not find input x of conv2d."));
auto* conv_in_node = subgraph.at(x); auto* conv_in_node = subgraph.at(x);
IR_NODE_LINK_TO(conv_in_node, new_conv_op); // Input IR_NODE_LINK_TO(conv_in_node, new_conv_op); // Input
......
...@@ -78,7 +78,9 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -78,7 +78,9 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
auto* new_conv_op = graph->CreateOpNode(&new_op_desc); auto* new_conv_op = graph->CreateOpNode(&new_op_desc);
// Link inputs and outputs. // Link inputs and outputs.
PADDLE_ENFORCE(subgraph.count(x)); PADDLE_ENFORCE_NE(
subgraph.count(x), 0,
platform::errors::NotFound("Detector did not find input x of conv2d."));
auto* conv_in_node = subgraph.at(x); auto* conv_in_node = subgraph.at(x);
IR_NODE_LINK_TO(conv_in_node, new_conv_op); // Input IR_NODE_LINK_TO(conv_in_node, new_conv_op); // Input
......
...@@ -66,7 +66,9 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -66,7 +66,9 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
auto* new_conv_op = graph->CreateOpNode(&new_op_desc); auto* new_conv_op = graph->CreateOpNode(&new_op_desc);
// Link inputs and outputs. // Link inputs and outputs.
PADDLE_ENFORCE(subgraph.count(x)); PADDLE_ENFORCE_NE(
subgraph.count(x), 0,
platform::errors::NotFound("Detector did not find input x of conv2d."));
auto* conv_in_node = subgraph.at(x); auto* conv_in_node = subgraph.at(x);
IR_NODE_LINK_TO(conv_in_node, new_conv_op); // Input IR_NODE_LINK_TO(conv_in_node, new_conv_op); // Input
......
...@@ -64,17 +64,23 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -64,17 +64,23 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef SET_IN #undef SET_IN
// Multiply embeddings with Weights // Multiply embeddings with Weights
PADDLE_ENFORCE(scope); PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
const std::string& embeddings = patterns::UniqueKey("Embeddings"); const std::string& embeddings = patterns::UniqueKey("Embeddings");
auto* embeddings_var = scope->Var(embeddings); auto* embeddings_var = scope->Var(embeddings);
PADDLE_ENFORCE(embeddings_var); PADDLE_ENFORCE_NOT_NULL(
embeddings_var,
platform::errors::InvalidArgument(
"Embeddings variable's pointer cannot be nullptr."));
auto* embeddings_tensor = auto* embeddings_tensor =
embeddings_var->GetMutable<framework::LoDTensor>(); embeddings_var->GetMutable<framework::LoDTensor>();
// Get WeightX size: [single_embedding, fc_size] // Get WeightX size: [single_embedding, fc_size]
// and embedding size: [dict_size, single_embedding] // and embedding size: [dict_size, single_embedding]
// and create new size of embeddings eg. [dict_size , hidden_size] // and create new size of embeddings eg. [dict_size , hidden_size]
auto* embedding_var = scope->FindVar(W->Name()); auto* embedding_var = scope->FindVar(W->Name());
PADDLE_ENFORCE(embedding_var); PADDLE_ENFORCE_NOT_NULL(
embedding_var, platform::errors::InvalidArgument(
"Embedding variable's pointer cannot be nullptr."));
const auto& embedding_tensor = embedding_var->Get<framework::LoDTensor>(); const auto& embedding_tensor = embedding_var->Get<framework::LoDTensor>();
const auto& weightx_tensor = const auto& weightx_tensor =
...@@ -90,7 +96,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -90,7 +96,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
// Adding biases to GEMM result to be // Adding biases to GEMM result to be
auto* lstm_bias_var = scope->FindVar(bias->Name()); auto* lstm_bias_var = scope->FindVar(bias->Name());
PADDLE_ENFORCE(lstm_bias_var); PADDLE_ENFORCE_NOT_NULL(lstm_bias_var,
platform::errors::InvalidArgument(
"Lstm bias var ptr cannot be nullptr."));
const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>(); const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();
auto alpha = 1.0f; auto alpha = 1.0f;
......
...@@ -56,8 +56,17 @@ TEST(FCElementwiseLayerNormFusePass, basic) { ...@@ -56,8 +56,17 @@ TEST(FCElementwiseLayerNormFusePass, basic) {
GetNumOpNodes(graph, "fused_fc_elementwise_layernorm"); GetNumOpNodes(graph, "fused_fc_elementwise_layernorm");
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1); num_nodes_before, num_nodes_after + 6,
platform::errors::InvalidArgument(
"After pass, the number of nodes should be reduced by 6, but the "
"number before pass is %d, after pass is %d.",
num_nodes_before, num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1,
platform::errors::InvalidArgument(
"After pass, the number of nodes of type "
"'fused_fc_elementwise_layernorm' should be 1, not %d.",
num_fused_nodes_after));
} }
} // namespace ir } // namespace ir
......
...@@ -25,7 +25,8 @@ namespace framework { ...@@ -25,7 +25,8 @@ namespace framework {
namespace ir { namespace ir {
void FCFusePass::ApplyImpl(ir::Graph* graph) const { void FCFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("fc_fuse", graph); FusePassBase::Init("fc_fuse", graph);
int found_fc_count = 0; int found_fc_count = 0;
......
...@@ -79,9 +79,17 @@ TEST(FCFusePass, basic) { ...@@ -79,9 +79,17 @@ TEST(FCFusePass, basic) {
int num_fc_nodes_after = GetNumOpNodes(graph, "fc"); int num_fc_nodes_after = GetNumOpNodes(graph, "fc");
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6); PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6,
PADDLE_ENFORCE_EQ(num_fc_nodes_after, 2); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(num_mul_nodes_before, num_fc_nodes_after); "num_nodes_before=%d, num_nodes_after=%d.",
num_nodes_before, num_nodes_after));
PADDLE_ENFORCE_EQ(num_fc_nodes_after, 2,
platform::errors::InvalidArgument("num_fc_nodes_after=%d.",
num_fc_nodes_after));
PADDLE_ENFORCE_EQ(num_mul_nodes_before, num_fc_nodes_after,
platform::errors::InvalidArgument(
"num_mul_nodes_before=%d, num_fc_nodes_after=%d.",
num_mul_nodes_before, num_fc_nodes_after));
} }
} // namespace ir } // namespace ir
......
...@@ -26,15 +26,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -26,15 +26,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::FC fc_pattern(pattern, name_scope);
patterns::GRU gru_pattern(pattern, name_scope);
PDNode* x = PDNode* x =
pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable(); pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable();
// Create pattern.
patterns::FC fc_pattern(pattern, name_scope);
auto* fc_out = fc_pattern(x, with_fc_bias, /* with_relu */ false); auto* fc_out = fc_pattern(x, with_fc_bias, /* with_relu */ false);
fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse. fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse.
patterns::GRU gru_pattern(pattern, name_scope);
gru_pattern(fc_out); gru_pattern(fc_out);
// Create New OpDesc // Create New OpDesc
...@@ -48,17 +48,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -48,17 +48,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
SET_IN(X, x); SET_IN(X, x);
SET_IN(WeightX, weight_x); SET_IN(WeightX, weight_x);
SET_IN(WeightH, weight_h); SET_IN(WeightH, weight_h);
if (with_fc_bias) { SET_IN(Bias, bias);
op_desc.SetInput("Bias", {NEW_NAME(bias) + bias->Name()});
} else {
SET_IN(Bias, bias);
}
#undef SET_IN #undef SET_IN
// TODO(grygielski): Add H0 to the pass
op_desc.SetInput("H0", {}); op_desc.SetInput("H0", {});
op_desc.SetOutput("Hidden", {hidden->Name()}); op_desc.SetOutput("Hidden", {hidden->Name()});
op_desc.SetAttr("is_reverse", gru->Op()->GetAttr("is_reverse")); op_desc.SetAttr("is_reverse", gru->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("origin_mode",
gru->Op()->GetAttrIfExists<bool>("origin_mode"));
// TODO(TJ): This should be a option for infer // TODO(TJ): This should be a option for infer
op_desc.SetAttr("use_seq", true); op_desc.SetAttr("use_seq", true);
op_desc.SetAttr("activation", gru->Op()->GetAttr("activation"));
op_desc.SetAttr("gate_activation", gru->Op()->GetAttr("gate_activation"));
#define SET_IMTERMEDIATE_OUT(key) op_desc.SetOutput(#key, {NEW_NAME(key)}) #define SET_IMTERMEDIATE_OUT(key) op_desc.SetOutput(#key, {NEW_NAME(key)})
SET_IMTERMEDIATE_OUT(ReorderedH0); SET_IMTERMEDIATE_OUT(ReorderedH0);
...@@ -68,26 +69,30 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -68,26 +69,30 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef SET_IMTERMEDIATE_OUT #undef SET_IMTERMEDIATE_OUT
auto* op = graph->CreateOpNode(&op_desc); auto* op = graph->CreateOpNode(&op_desc);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto& scope = graph->Get<Scope>(kParamScopeAttr);
if (with_fc_bias) { if (with_fc_bias) {
// Fusion GRU bias = fcbias + grubias auto* gru_bias_var = scope->FindVar(bias->Name());
auto* fusion_bias_var = scope.Var(NEW_NAME(bias) + bias->Name()); auto* fc_bias_var = scope->FindVar(fc_bias->Name());
auto* out_bias_tensor = PADDLE_ENFORCE_NE(
fusion_bias_var->GetMutable<framework::LoDTensor>(); gru_bias_var, nullptr,
PADDLE_ENFORCE(fusion_bias_var); platform::errors::NotFound("GRU bias var has not been found."));
auto* gru_bias_var = scope.FindVar(bias->Name()); PADDLE_ENFORCE_NE(
auto* fc_bias_var = scope.FindVar(fc_bias->Name()); fc_bias_var, nullptr,
PADDLE_ENFORCE(gru_bias_var); platform::errors::NotFound("FC bias var has not been found."));
PADDLE_ENFORCE(fc_bias_var);
const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>(); auto* gru_bias_tensor = gru_bias_var->GetMutable<LoDTensor>();
const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>(); auto* fc_bias_tensor = fc_bias_var->GetMutable<LoDTensor>();
// new bias = fc bias + gru bias PADDLE_ENFORCE_EQ(
out_bias_tensor->Resize(gru_bias_tenosr.dims()); gru_bias_tensor->numel(), fc_bias_tensor->numel(),
auto* data = out_bias_tensor->mutable_data<float>(platform::CPUPlace()); platform::errors::PreconditionNotMet(
for (int i = 0; i < out_bias_tensor->numel(); i++) { "GRU and FC biases have to have equal number of elements."));
data[i] =
fc_bias_tensor.data<float>()[i] + gru_bias_tenosr.data<float>()[i]; auto gru_bias_data =
gru_bias_tensor->mutable_data<float>(platform::CPUPlace());
auto* fc_bias_data = fc_bias_tensor->data<float>();
// Recompute GRU bias
for (int i = 0; i < gru_bias_tensor->numel(); ++i) {
gru_bias_data[i] += fc_bias_data[i];
} }
} }
#undef GET_NODE #undef GET_NODE
...@@ -108,7 +113,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -108,7 +113,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
IR_NODE_LINK_TO(x, op); IR_NODE_LINK_TO(x, op);
IR_NODE_LINK_TO(weight_x, op); IR_NODE_LINK_TO(weight_x, op);
IR_NODE_LINK_TO(weight_h, op); IR_NODE_LINK_TO(weight_h, op);
IR_NODE_LINK_TO(bias, op); // actually should link to new bias if have IR_NODE_LINK_TO(bias, op);
IR_NODE_LINK_TO(op, hidden); IR_NODE_LINK_TO(op, hidden);
// h0? // h0?
return op; return op;
......
...@@ -52,13 +52,17 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -52,13 +52,17 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
#undef SET_IN #undef SET_IN
if (with_fc_bias) { if (with_fc_bias) {
// Add FC-bias with LSTM-bias and create a new weight // Add FC-bias with LSTM-bias and create a new weight
PADDLE_ENFORCE(scope); PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
const std::string& new_bias_var = patterns::UniqueKey("NewBias"); const std::string& new_bias_var = patterns::UniqueKey("NewBias");
auto* bias_var = scope->Var(new_bias_var); auto* bias_var = scope->Var(new_bias_var);
PADDLE_ENFORCE(bias_var); PADDLE_ENFORCE_NOT_NULL(bias_var, platform::errors::InvalidArgument(
"Bias var ptr cannot be nullptr."));
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>(); auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
auto* lstm_bias_var = scope->FindVar(bias->Name()); auto* lstm_bias_var = scope->FindVar(bias->Name());
PADDLE_ENFORCE(lstm_bias_var); PADDLE_ENFORCE_NOT_NULL(lstm_bias_var,
platform::errors::InvalidArgument(
"Lstm bias var ptr cannot be nullptr."));
const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>(); const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();
bias_tensor->Resize(lstm_bias_tensor.dims()); bias_tensor->Resize(lstm_bias_tensor.dims());
......
...@@ -320,7 +320,7 @@ std::vector<Node *> FuseBatchNormActPass::ReplaceNode( ...@@ -320,7 +320,7 @@ std::vector<Node *> FuseBatchNormActPass::ReplaceNode(
return node; return node;
}); });
PADDLE_ENFORCE_EQ(has_replaced, true, PADDLE_ENFORCE_EQ(has_replaced, true,
platform::errors::NotFound("Not find %s in the node list.", platform::errors::NotFound("Not found %s in the node list.",
cur_node->Name())); cur_node->Name()));
return new_list; return new_list;
} }
......
...@@ -42,7 +42,8 @@ void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const { ...@@ -42,7 +42,8 @@ void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
// ele_add(x, act(y)) // ele_add(x, act(y))
ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct( ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const { ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("elewise_add_act", graph); FusePassBase::Init("elewise_add_act", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -93,7 +94,8 @@ ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct( ...@@ -93,7 +94,8 @@ ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct(
// act(ele_add(x,y)) // act(ele_add(x,y))
ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd( ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const { ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("act_elewise_add", graph); FusePassBase::Init("act_elewise_add", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -145,7 +147,8 @@ ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd( ...@@ -145,7 +147,8 @@ ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"] // ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad( ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const { ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("elewise_add_act_grad", graph); FusePassBase::Init("elewise_add_act_grad", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -252,10 +255,11 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { ...@@ -252,10 +255,11 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
bool save_intermediate_out = BOOST_GET_CONST( bool save_intermediate_out = BOOST_GET_CONST(
bool, cur_node->Op()->GetAttr("save_intermediate_out")); bool, cur_node->Op()->GetAttr("save_intermediate_out"));
auto intermediate_out_args = cur_node->Op()->Output("IntermediateOut"); auto intermediate_out_args = cur_node->Op()->Output("IntermediateOut");
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
save_intermediate_out && !intermediate_out_args.empty(), (save_intermediate_out && !intermediate_out_args.empty()), true,
"The %s should save the intermediate_out in the fusing stage.", platform::errors::InvalidArgument(
cur_node->Name()); "The %s should save the intermediate out in the fusing stage.",
cur_node->Name()));
// If the intermediate_out's output is empty, it should be removed. // If the intermediate_out's output is empty, it should be removed.
auto cur_node_outputs = cur_node->outputs; auto cur_node_outputs = cur_node->outputs;
...@@ -271,10 +275,11 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { ...@@ -271,10 +275,11 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
} else if (cur_node->Name() == "fused_elemwise_activation_grad") { } else if (cur_node->Name() == "fused_elemwise_activation_grad") {
auto intermediate_out_grad_args = auto intermediate_out_grad_args =
cur_node->Op()->Output(GradVarName("IntermediateOut")); cur_node->Op()->Output(GradVarName("IntermediateOut"));
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
!intermediate_out_grad_args.empty(), intermediate_out_grad_args.empty(), false,
"The %s should save the intermediate_out in the fusing stage.", platform::errors::InvalidArgument(
cur_node->Name()); "The %s should save the intermediate out in the fusing stage.",
cur_node->Name()));
auto cur_node_outputs = cur_node->outputs; auto cur_node_outputs = cur_node->outputs;
// If the intermediate_out_g's output is empty, it should be removed. // If the intermediate_out_g's output is empty, it should be removed.
for (auto &out : cur_node_outputs) { for (auto &out : cur_node_outputs) {
...@@ -312,7 +317,11 @@ void FuseElewiseAddActPass::ReLinkNodes(Graph *graph, ...@@ -312,7 +317,11 @@ void FuseElewiseAddActPass::ReLinkNodes(Graph *graph,
nodes2delete.emplace(out); nodes2delete.emplace(out);
} }
} else { } else {
PADDLE_ENFORCE(out == intermediate_out); PADDLE_ENFORCE_EQ(
out, intermediate_out,
platform::errors::InvalidArgument(
"Output of op(%s) must be %s, but not %s.", op_1->Name(),
intermediate_out->Name(), out->Name()));
IR_OP_VAR_LINK(fused_op, out); IR_OP_VAR_LINK(fused_op, out);
} }
} }
...@@ -347,8 +356,9 @@ std::vector<Node *> FuseElewiseAddActPass::ReplaceNode( ...@@ -347,8 +356,9 @@ std::vector<Node *> FuseElewiseAddActPass::ReplaceNode(
} }
return node; return node;
}); });
PADDLE_ENFORCE(has_replaced, "Not find %s in the node list.", PADDLE_ENFORCE_EQ(has_replaced, true,
cur_node->Name()); platform::errors::NotFound("Not found %s in the node list.",
cur_node->Name()));
return new_list; return new_list;
} }
......
...@@ -50,18 +50,25 @@ class FuseAdamOpPass : public FuseOptimizerOpPass { ...@@ -50,18 +50,25 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
fused_scale2->inputs.end()); fused_scale2->inputs.end());
for (auto &out_node : fused_scale1->outputs) { for (auto &out_node : fused_scale1->outputs) {
if (fused_scale2_in_nodes.count(out_node)) { if (fused_scale2_in_nodes.count(out_node)) {
PADDLE_ENFORCE(out_node->IsCtrlVar(), PADDLE_ENFORCE_EQ(out_node->IsCtrlVar(), true,
"The dependency var only should be ctrl var."); platform::errors::PreconditionNotMet(
"In adam op pass, the dependency var(%s) only "
"should be ctrl var.",
out_node->Name()));
not_need_ctrl_var_nodes.insert(out_node); not_need_ctrl_var_nodes.insert(out_node);
} }
} }
for (auto &node : not_need_ctrl_var_nodes) { for (auto &node : not_need_ctrl_var_nodes) {
// remove this node from the input op node. // remove this node from the input op node.
PADDLE_ENFORCE(!node->inputs.empty(), PADDLE_ENFORCE_EQ(
"The input should not be empty here."); node->inputs.empty(), false,
platform::errors::PreconditionNotMet(
"Node(%s)'s input should not be empty here.", node->Name()));
auto op_node = node->inputs.front(); auto op_node = node->inputs.front();
PADDLE_ENFORCE(op_node->IsOp()); PADDLE_ENFORCE_EQ(op_node->IsOp(), true,
platform::errors::PreconditionNotMet(
"Node(%s) should be an OP node.", op_node->Name()));
op_node->outputs.erase( op_node->outputs.erase(
remove_if( remove_if(
op_node->outputs.begin(), op_node->outputs.end(), op_node->outputs.begin(), op_node->outputs.end(),
...@@ -85,7 +92,9 @@ class FuseAdamOpPass : public FuseOptimizerOpPass { ...@@ -85,7 +92,9 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
const std::unordered_map<std::string, std::vector<std::string>> &vars_set, const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name, const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const { const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
PADDLE_ENFORCE_GT(adam_ops.size(), static_cast<size_t>(0)); PADDLE_ENFORCE_GT(
adam_ops.size(), static_cast<size_t>(0),
platform::errors::InvalidArgument("No adam op in the graph."));
// Check attributions // Check attributions
// NOTE: If new attribution is added, the following code maybe need change. // NOTE: If new attribution is added, the following code maybe need change.
...@@ -102,22 +111,58 @@ class FuseAdamOpPass : public FuseOptimizerOpPass { ...@@ -102,22 +111,58 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
int64_t, adam_ops[0]->Op()->GetAttr("min_row_size_to_use_multithread")); int64_t, adam_ops[0]->Op()->GetAttr("min_row_size_to_use_multithread"));
for (auto &adam_op : adam_ops) { for (auto &adam_op : adam_ops) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
beta1, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("beta1"))); beta1, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("beta1")),
platform::errors::PreconditionNotMet(
"All adam Op's attr(beta1) must be same, but there are two "
"different "
"value: %f, %f.",
beta1, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("beta1"))));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
beta2, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("beta2"))); beta2, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("beta2")),
platform::errors::PreconditionNotMet(
"All adam Op's attr(beta2) must be same, but there are two "
"different "
"value: %f, %f.",
beta2, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("beta2"))));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
epsilon, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("epsilon"))); epsilon, BOOST_GET_CONST(float, adam_op->Op()->GetAttr("epsilon")),
platform::errors::PreconditionNotMet(
"All adam Op's attr(epsilon) must be same, but there are two "
"different "
"value: %f, %f.",
epsilon,
BOOST_GET_CONST(float, adam_op->Op()->GetAttr("epsilon"))));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
lazy_mode, lazy_mode, BOOST_GET_CONST(bool, adam_op->Op()->GetAttr("lazy_mode")),
BOOST_GET_CONST(bool, adam_op->Op()->GetAttr("lazy_mode"))); platform::errors::PreconditionNotMet(
"All adam Op's attr(lazy_mode) must be same, but there are two "
"different "
"value: %d, %d.",
lazy_mode,
BOOST_GET_CONST(bool, adam_op->Op()->GetAttr("lazy_mode"))));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
min_row_size_to_use_multithread, min_row_size_to_use_multithread,
BOOST_GET_CONST(int64_t, adam_op->Op()->GetAttr( BOOST_GET_CONST(int64_t, adam_op->Op()->GetAttr(
"min_row_size_to_use_multithread"))); "min_row_size_to_use_multithread")),
platform::errors::PreconditionNotMet(
"All adam Op's attr(min_row_size_to_use_multithread) must be "
"same, but there are two different value: %I64, %I64.",
min_row_size_to_use_multithread,
BOOST_GET_CONST(
int64_t,
adam_op->Op()->GetAttr("min_row_size_to_use_multithread"))));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
op_role, op_role,
BOOST_GET_CONST(int, adam_op->Op()->GetAttr( BOOST_GET_CONST(int, adam_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName()))); OpProtoAndCheckerMaker::OpRoleAttrName())),
platform::errors::PreconditionNotMet(
"All adam Op's attr(op_role) must be same, but there are two "
"different "
"value: %d, %d.",
op_role,
BOOST_GET_CONST(int,
adam_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName()))));
} }
// NOTE: fused_var is only exist in scope, so the graph doesn't have // NOTE: fused_var is only exist in scope, so the graph doesn't have
...@@ -154,7 +199,10 @@ class FuseAdamOpPass : public FuseOptimizerOpPass { ...@@ -154,7 +199,10 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
const std::string &fused_var_name, const std::string &fused_var_name,
const std::vector<ir::Node *> &adam_ops, const std::vector<ir::Node *> &adam_ops,
ir::Graph *graph) const { ir::Graph *graph) const {
PADDLE_ENFORCE_EQ(beta_name.size(), adam_ops.size()); PADDLE_ENFORCE_EQ(beta_name.size(), adam_ops.size(),
platform::errors::InvalidArgument(
"Beta name size(%d) must equal to adam op size(%d).",
beta_name.size(), adam_ops.size()));
const std::string scale_op_name = "scale"; const std::string scale_op_name = "scale";
// Get the scale_ops of dealing the adam's beta var. // Get the scale_ops of dealing the adam's beta var.
...@@ -168,7 +216,9 @@ class FuseAdamOpPass : public FuseOptimizerOpPass { ...@@ -168,7 +216,9 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
return var_node->Var() && return var_node->Var() &&
var_node->Var()->Name() == beta_1_pow_name; var_node->Var()->Name() == beta_1_pow_name;
}); });
PADDLE_ENFORCE(beta_pow_iter != adam_ops[i]->inputs.end()); PADDLE_ENFORCE_NE(beta_pow_iter, adam_ops[i]->inputs.end(),
platform::errors::NotFound(
"Can not find %s in adam ops.", beta_1_pow_name));
auto beta_pow_node = *beta_pow_iter; auto beta_pow_node = *beta_pow_iter;
auto scale_op_iter = std::find_if( auto scale_op_iter = std::find_if(
...@@ -176,11 +226,18 @@ class FuseAdamOpPass : public FuseOptimizerOpPass { ...@@ -176,11 +226,18 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
[&scale_op_name](ir::Node *op_node) -> bool { [&scale_op_name](ir::Node *op_node) -> bool {
return op_node->Op() && op_node->Op()->Type() == scale_op_name; return op_node->Op() && op_node->Op()->Type() == scale_op_name;
}); });
PADDLE_ENFORCE(scale_op_iter != beta_pow_node->outputs.end()); PADDLE_ENFORCE_NE(
scale_op_iter, beta_pow_node->outputs.end(),
platform::errors::NotFound("Can not find %s in beta pow node.",
scale_op_name));
scale_ops.emplace_back(*scale_op_iter); scale_ops.emplace_back(*scale_op_iter);
} }
PADDLE_ENFORCE_EQ(scale_ops.size(), beta_name.size()); PADDLE_ENFORCE_EQ(
scale_ops.size(), beta_name.size(),
platform::errors::PreconditionNotMet(
"Beta name size(%d) must equal to scale ops size(%d).",
beta_name.size(), scale_ops.size()));
VLOG(6) << "The number of scale op is " << scale_ops.size() << "."; VLOG(6) << "The number of scale op is " << scale_ops.size() << ".";
// Check attributions // Check attributions
// NOTE: If new attribution is added, the following code maybe need change. // NOTE: If new attribution is added, the following code maybe need change.
...@@ -193,16 +250,40 @@ class FuseAdamOpPass : public FuseOptimizerOpPass { ...@@ -193,16 +250,40 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
BOOST_GET_CONST(bool, scale_ops[0]->Op()->GetAttr("bias_after_scale")); BOOST_GET_CONST(bool, scale_ops[0]->Op()->GetAttr("bias_after_scale"));
for (auto &scale_op : scale_ops) { for (auto &scale_op : scale_ops) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale, BOOST_GET_CONST(float, scale_op->Op()->GetAttr("scale"))); scale, BOOST_GET_CONST(float, scale_op->Op()->GetAttr("scale")),
platform::errors::PreconditionNotMet(
"All scale Op's attr(scale) must be same, but there are two "
"different "
"value: %f, %f.",
scale, BOOST_GET_CONST(float, scale_op->Op()->GetAttr("scale"))));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
bias, BOOST_GET_CONST(float, scale_op->Op()->GetAttr("bias"))); bias, BOOST_GET_CONST(float, scale_op->Op()->GetAttr("bias")),
platform::errors::PreconditionNotMet(
"All scale Op's attr(bias) must be same, but there are two "
"different "
"value: %f, %f.",
bias, BOOST_GET_CONST(float, scale_op->Op()->GetAttr("bias"))));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
bias_after_scale, bias_after_scale,
BOOST_GET_CONST(bool, scale_op->Op()->GetAttr("bias_after_scale"))); BOOST_GET_CONST(bool, scale_op->Op()->GetAttr("bias_after_scale")),
platform::errors::PreconditionNotMet(
"All scale Op's attr(bias_after_scale) must be same, but there "
"are two different value: %d, %d.",
bias_after_scale,
BOOST_GET_CONST(bool,
scale_op->Op()->GetAttr("bias_after_scale"))));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
op_role, op_role,
BOOST_GET_CONST(int, scale_op->Op()->GetAttr( BOOST_GET_CONST(int, scale_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName()))); OpProtoAndCheckerMaker::OpRoleAttrName())),
platform::errors::PreconditionNotMet(
"All scale Op's attr(op_role) must be same, but there are two "
"different "
"value: %d, %d.",
op_role,
BOOST_GET_CONST(int,
scale_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName()))));
} }
// NOTE: fused_var is only exist in scope, so the graph doesn't have // NOTE: fused_var is only exist in scope, so the graph doesn't have
......
...@@ -37,7 +37,9 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass { ...@@ -37,7 +37,9 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
const std::unordered_map<std::string, std::vector<std::string>> &vars_set, const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name, const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &momentum_ops, ir::Graph *graph) const { const std::vector<ir::Node *> &momentum_ops, ir::Graph *graph) const {
PADDLE_ENFORCE_GT(momentum_ops.size(), static_cast<size_t>(0)); PADDLE_ENFORCE_GT(
momentum_ops.size(), static_cast<size_t>(0),
platform::errors::InvalidArgument("Momentum ops must not be empyt."));
// Check attributions // Check attributions
// NOTE: If new attribution is added, the following code maybe need change. // NOTE: If new attribution is added, the following code maybe need change.
...@@ -50,14 +52,32 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass { ...@@ -50,14 +52,32 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
for (auto &momentum_op : momentum_ops) { for (auto &momentum_op : momentum_ops) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
mu, BOOST_GET_CONST(float, momentum_op->Op()->GetAttr("mu"))); mu, BOOST_GET_CONST(float, momentum_op->Op()->GetAttr("mu")),
platform::errors::InvalidArgument(
"All momentum Op's attr(mu) must be same, but there are two "
"different "
"value: %f, %f.",
mu, BOOST_GET_CONST(float, momentum_op->Op()->GetAttr("mu"))));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
use_nesterov, use_nesterov,
BOOST_GET_CONST(bool, momentum_op->Op()->GetAttr("use_nesterov"))); BOOST_GET_CONST(bool, momentum_op->Op()->GetAttr("use_nesterov")),
platform::errors::InvalidArgument(
"All momentum Op's attr(use_nesterov) must be same, but there "
"are two different value: %d, %d.",
use_nesterov, BOOST_GET_CONST(bool, momentum_op->Op()->GetAttr(
"use_nesterov"))));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
op_role, op_role,
BOOST_GET_CONST(int, momentum_op->Op()->GetAttr( BOOST_GET_CONST(int, momentum_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName()))); OpProtoAndCheckerMaker::OpRoleAttrName())),
platform::errors::InvalidArgument(
"All momentum Op's attr(op_role) must be same, but there are two "
"different "
"value: %d, %d.",
op_role,
BOOST_GET_CONST(int,
momentum_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName()))));
} }
// NOTE: fused_var is only exist in scope, so the graph doesn't have // NOTE: fused_var is only exist in scope, so the graph doesn't have
......
...@@ -41,10 +41,12 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { ...@@ -41,10 +41,12 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
for (auto &node : topo_nodes) { for (auto &node : topo_nodes) {
if (node->Op()->Type() == fuse_op_type) { if (node->Op()->Type() == fuse_op_type) {
auto grad_name = node->Op()->Input(kGrad); auto grad_name = node->Op()->Input(kGrad);
PADDLE_ENFORCE_EQ(grad_name.size(), static_cast<size_t>(1), PADDLE_ENFORCE_EQ(
"The %s operator has multiple gradient input. Expected " grad_name.size(), static_cast<size_t>(1),
"it to only have one gradient input.", platform::errors::InvalidArgument(
fuse_op_type); "The %s operator has multiple gradient input. Expected "
"it to only have one gradient input.",
fuse_op_type));
if (IsLoDTensorType(GetTypeOfVar(vars_info, grad_name[0]))) { if (IsLoDTensorType(GetTypeOfVar(vars_info, grad_name[0]))) {
opt_nodes.emplace_back(node); opt_nodes.emplace_back(node);
} }
...@@ -96,7 +98,8 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { ...@@ -96,7 +98,8 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
VLOG(6) << var_name << ": " << fused_var_name; VLOG(6) << var_name << ": " << fused_var_name;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
fused_var_set.count(fused_var_name), 0, fused_var_set.count(fused_var_name), 0,
platform::errors::AlreadyExists("The fused variable already exists.")); platform::errors::AlreadyExists(
"The fused variable(%s) already exists.", fused_var_name));
fused_var_set.insert(fused_var_name); fused_var_set.insert(fused_var_name);
fused_vars_name.emplace(var_name, fused_var_name); fused_vars_name.emplace(var_name, fused_var_name);
} }
...@@ -110,7 +113,10 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { ...@@ -110,7 +113,10 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
result.Get<details::ParamsAndGrads>(details::kParamsAndDenseGrads); result.Get<details::ParamsAndGrads>(details::kParamsAndDenseGrads);
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
params_and_dense_grads.size(), aux_var_map.at(kGrad).size(), params_and_dense_grads.size(), aux_var_map.at(kGrad).size(),
"The number of dense gradients should be little than optimizer ops."); platform::errors::InvalidArgument(
"The number of dense gradients(%d) should be "
"little than optimizer ops(%d).",
params_and_dense_grads.size(), aux_var_map.at(kGrad).size()));
std::unordered_set<std::string> opt_grad_set(aux_var_map.at(kGrad).size()); std::unordered_set<std::string> opt_grad_set(aux_var_map.at(kGrad).size());
for (auto &p_g : params_and_dense_grads) { for (auto &p_g : params_and_dense_grads) {
...@@ -130,13 +136,14 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { ...@@ -130,13 +136,14 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
// some gradient's name maybe changed. // some gradient's name maybe changed.
if (new_grad_idx.size() == 0) { if (new_grad_idx.size() == 0) {
if (!result.Has(details::kFusedGrads)) { if (!result.Has(details::kFusedGrads)) {
PADDLE_THROW( PADDLE_THROW(platform::errors::PreconditionNotMet(
"The coalesce_grad_tensor_pass should " "The coalesce_grad_tensor_pass should "
"be called before this pass."); "be called before this pass."));
} }
auto &fused_grad = result.Get<details::FusedGrads>(details::kFusedGrads); auto &fused_grad = result.Get<details::FusedGrads>(details::kFusedGrads);
PADDLE_ENFORCE_NE(fused_grad.size(), 0, PADDLE_ENFORCE_NE(fused_grad.size(), 0,
"The fused gradient should not be empty."); platform::errors::NotFound(
"The fused gradient should not be empty."));
if (fused_grad.size() > 1) { if (fused_grad.size() > 1) {
// Note(chenweihang): Because the dtype of those gradients is not // Note(chenweihang): Because the dtype of those gradients is not
// unified,so the number of fused gradients is more than one, // unified,so the number of fused gradients is more than one,
...@@ -146,8 +153,9 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { ...@@ -146,8 +153,9 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
auto &fused_vars = result.Get<details::FusedVars>(details::kFusedVars); auto &fused_vars = result.Get<details::FusedVars>(details::kFusedVars);
auto iter = auto iter =
std::find(fused_vars.begin(), fused_vars.end(), fused_grad.front()); std::find(fused_vars.begin(), fused_vars.end(), fused_grad.front());
PADDLE_ENFORCE_EQ(iter != fused_vars.end(), true, PADDLE_ENFORCE_EQ(
"Not found the fused gradient variable."); iter != fused_vars.end(), true,
platform::errors::NotFound("Not found the fused gradient variable."));
fused_vars_name[kGrad] = fused_grad.front(); fused_vars_name[kGrad] = fused_grad.front();
// Sort the parameters and auxiliary variables according // Sort the parameters and auxiliary variables according
...@@ -334,16 +342,24 @@ void FuseOptimizerOpPass::FuseGradientsToContinuousSpace( ...@@ -334,16 +342,24 @@ void FuseOptimizerOpPass::FuseGradientsToContinuousSpace(
// The Gradients should not be reused during memory optimization. // The Gradients should not be reused during memory optimization.
for (auto &grad_var_name : grads) { for (auto &grad_var_name : grads) {
auto iter = vars_info.find(grad_var_name); auto iter = vars_info.find(grad_var_name);
PADDLE_ENFORCE_EQ(iter != vars_info.end(), true, PADDLE_ENFORCE_EQ(
"The gradient variable %s is not found.", grad_var_name); iter != vars_info.end(), true,
PADDLE_ENFORCE_EQ(!iter->second.empty(), true, platform::errors::NotFound("The gradient variable %s is not found.",
"The gradient var node %s is not found.", grad_var_name); grad_var_name));
PADDLE_ENFORCE_NOT_NULL(iter->second.front()->Var(), PADDLE_ENFORCE_EQ(
"The gradient var node is null."); !iter->second.empty(), true,
platform::errors::NotFound("The gradient var node %s is not found.",
grad_var_name));
PADDLE_ENFORCE_NOT_NULL(
iter->second.front()->Var(),
platform::errors::InvalidArgument("The gradient var(%s) node is null.",
grad_var_name));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
IsLoDTensorType(iter->second.front()->Var()->GetType()), true, IsLoDTensorType(iter->second.front()->Var()->GetType()), true,
"Currently the gradient type only should be LoDTensor when " platform::errors::InvalidArgument(
"fusing optimizer ops."); "Currently the gradient(%s) type only should be LoDTensor when "
"fusing optimizer ops.",
grad_var_name));
for (auto var : iter->second) { for (auto var : iter->second) {
pinned_var_set.insert(var->Var()->Name()); pinned_var_set.insert(var->Var()->Name());
} }
...@@ -382,11 +398,14 @@ const VarDesc *FuseOptimizerOpPass::GetVarDescFromVarsInfo( ...@@ -382,11 +398,14 @@ const VarDesc *FuseOptimizerOpPass::GetVarDescFromVarsInfo(
const std::string &var_name) const { const std::string &var_name) const {
auto grad_iter = vars_info.find(var_name); auto grad_iter = vars_info.find(var_name);
PADDLE_ENFORCE_EQ(grad_iter != vars_info.end(), true, PADDLE_ENFORCE_EQ(grad_iter != vars_info.end(), true,
"The gradient variable %s is not found.", var_name); platform::errors::NotFound(
"The gradient variable %s is not found.", var_name));
PADDLE_ENFORCE_EQ(!grad_iter->second.empty(), true, PADDLE_ENFORCE_EQ(!grad_iter->second.empty(), true,
"The gradient var node %s is not found.", var_name); platform::errors::NotFound(
"The gradient var node %s is not found.", var_name));
PADDLE_ENFORCE_NOT_NULL(grad_iter->second.front()->Var(), PADDLE_ENFORCE_NOT_NULL(grad_iter->second.front()->Var(),
"The gradient var node is null."); platform::errors::InvalidArgument(
"The gradient var(%s) node is null.", var_name));
return grad_iter->second.front()->Var(); return grad_iter->second.front()->Var();
} }
...@@ -428,8 +447,9 @@ void FuseOptimizerOpPass::SortParametersAndAuxVars( ...@@ -428,8 +447,9 @@ void FuseOptimizerOpPass::SortParametersAndAuxVars(
const std::vector<std::pair<std::string, std::string>> &params_grads, const std::vector<std::pair<std::string, std::string>> &params_grads,
std::unordered_map<std::string, std::vector<std::string>> *aux_var_map, std::unordered_map<std::string, std::vector<std::string>> *aux_var_map,
std::vector<ir::Node *> *ops) const { std::vector<ir::Node *> *ops) const {
PADDLE_ENFORCE_NE(aux_var_map->count(kGrad), static_cast<size_t>(0), PADDLE_ENFORCE_NE(
"The gradient variable doesn‘t exist."); aux_var_map->count(kGrad), static_cast<size_t>(0),
platform::errors::NotFound("The gradient variable doesn‘t exist."));
auto &grad_vec = aux_var_map->at(kGrad); auto &grad_vec = aux_var_map->at(kGrad);
std::vector<size_t> grad_sort_idx; std::vector<size_t> grad_sort_idx;
...@@ -437,8 +457,10 @@ void FuseOptimizerOpPass::SortParametersAndAuxVars( ...@@ -437,8 +457,10 @@ void FuseOptimizerOpPass::SortParametersAndAuxVars(
for (auto &p_g : params_grads) { for (auto &p_g : params_grads) {
auto iter = std::find(grad_vec.begin(), grad_vec.end(), p_g.second); auto iter = std::find(grad_vec.begin(), grad_vec.end(), p_g.second);
PADDLE_ENFORCE_EQ(iter != grad_vec.end(), true, PADDLE_ENFORCE_EQ(
"%s is not found in gradient vector", p_g.second); iter != grad_vec.end(), true,
platform::errors::NotFound(
"Parameter@Grad(%s) is not found in gradient vector.", p_g.second));
auto idx = std::distance(grad_vec.begin(), iter); auto idx = std::distance(grad_vec.begin(), iter);
grad_sort_idx.emplace_back(idx); grad_sort_idx.emplace_back(idx);
} }
...@@ -477,9 +499,10 @@ void FuseOptimizerOpPass::GetFusingVarNamesMap( ...@@ -477,9 +499,10 @@ void FuseOptimizerOpPass::GetFusingVarNamesMap(
for (auto &var_n : aux_vars_name) { for (auto &var_n : aux_vars_name) {
auto arg_names = node->Op()->Input(var_n); auto arg_names = node->Op()->Input(var_n);
PADDLE_ENFORCE_EQ(arg_names.size(), static_cast<size_t>(1), PADDLE_ENFORCE_EQ(arg_names.size(), static_cast<size_t>(1),
"The input variable of optimizer to be fused is " platform::errors::InvalidArgument(
"invalid. Excepted %s only has one %s input.", "The input variable of optimizer to be fused is "
node->Op()->Type(), var_n); "invalid. Excepted %s only has one %s input.",
node->Op()->Type(), var_n));
(*aux_args_name)[var_n].emplace_back(arg_names[0]); (*aux_args_name)[var_n].emplace_back(arg_names[0]);
} }
} }
...@@ -525,10 +548,14 @@ void FuseOptimizerOpPass::InsertInputAndOutputForFusedOpNode( ...@@ -525,10 +548,14 @@ void FuseOptimizerOpPass::InsertInputAndOutputForFusedOpNode(
auto deal_with_ctrl_vars = [&out_dep_vars, &not_useful_vars, auto deal_with_ctrl_vars = [&out_dep_vars, &not_useful_vars,
&fused_opt_node](ir::Node *ctr_var_node) { &fused_opt_node](ir::Node *ctr_var_node) {
PADDLE_ENFORCE_EQ(ctr_var_node->inputs.size(), 1, PADDLE_ENFORCE_EQ(ctr_var_node->inputs.size(), 1,
"The control var node has nultiple inputs."); platform::errors::InvalidArgument(
"The control var(%s) node has multiple inputs.",
ctr_var_node->Name()));
if (ctr_var_node->inputs.front() == fused_opt_node) { if (ctr_var_node->inputs.front() == fused_opt_node) {
PADDLE_ENFORCE_GT(ctr_var_node->outputs.size(), 0, PADDLE_ENFORCE_GT(
"The control var node has no output."); ctr_var_node->outputs.size(), 0,
platform::errors::InvalidArgument(
"The control var(%s) node has no output.", ctr_var_node->Name()));
auto output_ops = ctr_var_node->outputs; auto output_ops = ctr_var_node->outputs;
output_ops.erase(std::remove_if(output_ops.begin(), output_ops.end(), output_ops.erase(std::remove_if(output_ops.begin(), output_ops.end(),
[&fused_opt_node](const ir::Node *node) { [&fused_opt_node](const ir::Node *node) {
......
...@@ -35,7 +35,9 @@ class FuseSgdOpPass : public FuseOptimizerOpPass { ...@@ -35,7 +35,9 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
const std::unordered_map<std::string, std::vector<std::string>> &vars_set, const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name, const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const { const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const {
PADDLE_ENFORCE_GT(sgd_ops.size(), static_cast<size_t>(0)); PADDLE_ENFORCE_GT(
sgd_ops.size(), static_cast<size_t>(0),
platform::errors::InvalidArgument("SGD ops must not be empyt."));
// NOTE: fused_var is only exist in scope, so the graph doesn't have // NOTE: fused_var is only exist in scope, so the graph doesn't have
// fused_var node. // fused_var node.
......
...@@ -25,14 +25,19 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const { ...@@ -25,14 +25,19 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const {
} }
Scope* FusePassBase::param_scope() const { Scope* FusePassBase::param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr)); PADDLE_ENFORCE_EQ(graph_->Has(kParamScopeAttr), true,
platform::errors::InvalidArgument(
"Graph must have kParamScopeAttr attribute."));
auto& scope = graph_->Get<framework::Scope>(kParamScopeAttr); auto& scope = graph_->Get<framework::Scope>(kParamScopeAttr);
return &scope; return &scope;
} }
void FusePassBase::AddStatis(int count_of_fused) const { void FusePassBase::AddStatis(int count_of_fused) const {
PADDLE_ENFORCE(graph_); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE(!repr_.empty()); graph_, platform::errors::InvalidArgument("Graph cannot be nullptr."));
PADDLE_ENFORCE_EQ(repr_.empty(), false,
platform::errors::InvalidArgument(
"Fuse pass must be initialized with a name."));
if (!graph_->Has(kFuseStatisAttr)) { if (!graph_->Has(kFuseStatisAttr)) {
graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>); graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
} }
......
...@@ -31,7 +31,8 @@ void FuseReluDepthwiseConvPass::ApplyImpl(ir::Graph *graph) const { ...@@ -31,7 +31,8 @@ void FuseReluDepthwiseConvPass::ApplyImpl(ir::Graph *graph) const {
ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv( ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
ir::Graph *graph, bool only_forward) const { ir::Graph *graph, bool only_forward) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
if (only_forward) if (only_forward)
FusePassBase::Init("relu_depthwise_conv_only_forward", graph); FusePassBase::Init("relu_depthwise_conv_only_forward", graph);
else else
...@@ -110,23 +111,45 @@ ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv( ...@@ -110,23 +111,45 @@ ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
xg_var = subgraph.at(xg)->Var(); xg_var = subgraph.at(xg)->Var();
} }
PADDLE_ENFORCE_EQ(layer_op->Input("Input").size(), 1UL); PADDLE_ENFORCE_EQ(layer_op->Input("Input").size(), 1UL,
PADDLE_ENFORCE_EQ(layer_op->Input("Input")[0], y_var->Name()); platform::errors::InvalidArgument(
"Op(%s)'s input size(%d) must be 1.",
layer_op->Type(), layer_op->Input("Input").size()));
PADDLE_ENFORCE_EQ(
layer_op->Input("Input")[0], y_var->Name(),
platform::errors::InvalidArgument(
"Op(%s)'s input name(%s) must be %s.", layer_op->Type(),
layer_op->Input("Input")[0], y_var->Name()));
layer_op->SetInput("Input", {x_var->Name()}); layer_op->SetInput("Input", {x_var->Name()});
subgraph.at(layer)->inputs.push_back(subgraph.at(x)); subgraph.at(layer)->inputs.push_back(subgraph.at(x));
subgraph.at(x)->outputs.push_back(subgraph.at(layer)); subgraph.at(x)->outputs.push_back(subgraph.at(layer));
VLOG(4) << "replace " << y_var->Name() << " -> " << x_var->Name(); VLOG(4) << "replace " << y_var->Name() << " -> " << x_var->Name();
if (!only_forward) { if (!only_forward) {
PADDLE_ENFORCE_EQ(layer_g_op->Input("Input").size(), 1UL); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(layer_g_op->Input("Input")[0], y_var->Name()); layer_g_op->Input("Input").size(), 1UL,
platform::errors::InvalidArgument(
"Op(%s)'s input size(%d) must be 1.", layer_g_op->Type(),
layer_g_op->Input("Input").size()));
PADDLE_ENFORCE_EQ(
layer_g_op->Input("Input")[0], y_var->Name(),
platform::errors::InvalidArgument(
"Op(%s)'s input name(%s) must be %s.", layer_g_op->Type(),
layer_g_op->Input("Input")[0], y_var->Name()));
layer_g_op->SetInput("Input", {x_var->Name()}); layer_g_op->SetInput("Input", {x_var->Name()});
subgraph.at(layer_g)->inputs.push_back(subgraph.at(x)); subgraph.at(layer_g)->inputs.push_back(subgraph.at(x));
subgraph.at(x)->outputs.push_back(subgraph.at(layer_g)); subgraph.at(x)->outputs.push_back(subgraph.at(layer_g));
PADDLE_ENFORCE_EQ(layer_g_op->Output(GradVarName("Input")).size(), 1UL); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(layer_g_op->Output(GradVarName("Input"))[0], layer_g_op->Output(GradVarName("Input")).size(), 1UL,
yg_var->Name()); platform::errors::InvalidArgument(
"Op(%s)'s input size(%d) must be 1.", layer_g_op->Type(),
layer_g_op->Output(GradVarName("Input")).size()));
PADDLE_ENFORCE_EQ(
layer_g_op->Output(GradVarName("Input"))[0], yg_var->Name(),
platform::errors::InvalidArgument(
"Op(%s)'s input name(%s) must be %s.", layer_g_op->Type(),
layer_g_op->Output(GradVarName("Input"))[0], yg_var->Name()));
layer_g_op->SetOutput(GradVarName("Input"), {xg_var->Name()}); layer_g_op->SetOutput(GradVarName("Input"), {xg_var->Name()});
subgraph.at(layer_g)->outputs.push_back(subgraph.at(xg)); subgraph.at(layer_g)->outputs.push_back(subgraph.at(xg));
subgraph.at(xg)->inputs.push_back(subgraph.at(layer_g)); subgraph.at(xg)->inputs.push_back(subgraph.at(layer_g));
......
...@@ -136,7 +136,9 @@ bool FindCircleSubGraph(const Graph &graph, ...@@ -136,7 +136,9 @@ bool FindCircleSubGraph(const Graph &graph,
std::vector<ir::Node *> TopologySortOperations(const Graph &graph) { std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp> std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
adj_list = BuildOperationAdjList(graph); adj_list = BuildOperationAdjList(graph);
PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr)); PADDLE_ENFORCE_EQ(HasCircleInternal(adj_list, nullptr), false,
platform::errors::InvalidArgument(
"Generated graph shouldn't contain cycle."));
std::unordered_set<ir::Node *> visited; std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret; std::vector<ir::Node *> ret;
for (auto adj : adj_list) { for (auto adj : adj_list) {
...@@ -161,7 +163,11 @@ BuildOperationAdjList(const Graph &graph) { ...@@ -161,7 +163,11 @@ BuildOperationAdjList(const Graph &graph) {
} }
for (auto &var : n->inputs) { for (auto &var : n->inputs) {
for (auto &adj_n : var->inputs) { for (auto &adj_n : var->inputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation); PADDLE_ENFORCE_EQ(
adj_n->NodeType(), ir::Node::Type::kOperation,
platform::errors::InvalidArgument(
"Node(%s)'s type(%d) must be kOperation type.", adj_n->Name(),
static_cast<int>(adj_n->NodeType())));
VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n) VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n) << " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var); << " via " << var->Name() << reinterpret_cast<void *>(var);
...@@ -184,7 +190,11 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationOutAdjList( ...@@ -184,7 +190,11 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationOutAdjList(
} }
for (auto &var : n->outputs) { for (auto &var : n->outputs) {
for (auto &adj_n : var->outputs) { for (auto &adj_n : var->outputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation); PADDLE_ENFORCE_EQ(
adj_n->NodeType(), ir::Node::Type::kOperation,
platform::errors::InvalidArgument(
"Node(%s)'s type(%d) must be kOperation type.", adj_n->Name(),
static_cast<int>(adj_n->NodeType())));
VLOG(40) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n) VLOG(40) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n) << " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var); << " via " << var->Name() << reinterpret_cast<void *>(var);
...@@ -359,7 +369,10 @@ size_t GraphNum(const Graph &graph) { ...@@ -359,7 +369,10 @@ size_t GraphNum(const Graph &graph) {
} }
std::unique_ptr<std::ostream> fout( std::unique_ptr<std::ostream> fout(
new std::ofstream(FLAGS_print_sub_graph_dir)); new std::ofstream(FLAGS_print_sub_graph_dir));
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE_EQ(fout->good(), true,
platform::errors::Unavailable(
"Can not open file %s for printing the graph.",
FLAGS_print_sub_graph_dir));
*fout << out.str(); *fout << out.str();
} }
} }
......
...@@ -37,12 +37,14 @@ NodesDFSIterator::NodesDFSIterator(const NodesDFSIterator &other) ...@@ -37,12 +37,14 @@ NodesDFSIterator::NodesDFSIterator(const NodesDFSIterator &other)
: stack_(other.stack_), visited_(other.visited_) {} : stack_(other.stack_), visited_(other.visited_) {}
Node &NodesDFSIterator::operator*() { Node &NodesDFSIterator::operator*() {
PADDLE_ENFORCE(!stack_.empty()); PADDLE_ENFORCE_EQ(stack_.empty(), false, platform::errors::OutOfRange(
"The iterator exceeds range."));
return *stack_.top(); return *stack_.top();
} }
NodesDFSIterator &NodesDFSIterator::operator++() { NodesDFSIterator &NodesDFSIterator::operator++() {
PADDLE_ENFORCE(!stack_.empty(), "the iterator exceeds range"); PADDLE_ENFORCE_EQ(stack_.empty(), false, platform::errors::OutOfRange(
"The iterator exceeds range."));
visited_.insert(stack_.top()); visited_.insert(stack_.top());
auto *cur = stack_.top(); auto *cur = stack_.top();
stack_.pop(); stack_.pop();
...@@ -73,11 +75,18 @@ inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) { ...@@ -73,11 +75,18 @@ inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) {
} }
NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) { NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) {
PADDLE_ENFORCE(!source.empty(), PADDLE_ENFORCE_EQ(
"Start points of topological sorting should not be empty!"); source.empty(), false,
platform::errors::InvalidArgument(
"Start points of topological sorting should not be empty!"));
// CHECK all the inputs' in-degree is 0 // CHECK all the inputs' in-degree is 0
for (auto *node : source) { for (auto *node : source) {
PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0)); PADDLE_ENFORCE_EQ(
CheckNodeIndegreeEquals(*node, 0), true,
platform::errors::InvalidArgument(
"In start points of topological sorting, the indegree of each "
"point should be 0. Node(%s)'s indegree is not 0.",
node->Name()));
} }
std::set<Node *> to_visit{source.begin(), source.end()}; std::set<Node *> to_visit{source.begin(), source.end()};
...@@ -106,7 +115,11 @@ NodesTSIterator::NodesTSIterator(const NodesTSIterator &other) ...@@ -106,7 +115,11 @@ NodesTSIterator::NodesTSIterator(const NodesTSIterator &other)
: sorted_(other.sorted_), cursor_(other.cursor_) {} : sorted_(other.sorted_), cursor_(other.cursor_) {}
Node &NodesTSIterator::operator*() { Node &NodesTSIterator::operator*() {
PADDLE_ENFORCE_LT(cursor_, sorted_.size()); PADDLE_ENFORCE_LT(
cursor_, sorted_.size(),
platform::errors::OutOfRange(
"The iterator exceeds range. Container size is %d, but index is %d.",
sorted_.size(), cursor_));
return *sorted_[cursor_]; return *sorted_[cursor_];
} }
...@@ -128,7 +141,11 @@ bool NodesTSIterator::operator==(const NodesTSIterator &other) { ...@@ -128,7 +141,11 @@ bool NodesTSIterator::operator==(const NodesTSIterator &other) {
} }
Node *NodesTSIterator::operator->() { Node *NodesTSIterator::operator->() {
PADDLE_ENFORCE_LT(cursor_, sorted_.size()); PADDLE_ENFORCE_LT(
cursor_, sorted_.size(),
platform::errors::OutOfRange(
"The iterator exceeds range. Container size is %d, but index is %d.",
sorted_.size(), cursor_));
return sorted_[cursor_]; return sorted_[cursor_];
} }
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#pragma once #pragma once
#include <stack> #include <stack>
#include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
...@@ -66,7 +68,7 @@ struct NodesDFSIterator ...@@ -66,7 +68,7 @@ struct NodesDFSIterator
struct NodesTSIterator struct NodesTSIterator
: public std::iterator<std::forward_iterator_tag, Node *> { : public std::iterator<std::forward_iterator_tag, Node *> {
NodesTSIterator() = default; NodesTSIterator() = default;
NodesTSIterator(const std::vector<Node *> &source); explicit NodesTSIterator(const std::vector<Node *> &source);
NodesTSIterator(NodesTSIterator &&other) NodesTSIterator(NodesTSIterator &&other)
: sorted_(std::move(other.sorted_)), cursor_(other.cursor_) { : sorted_(std::move(other.sorted_)), cursor_(other.cursor_) {
other.cursor_ = 0; other.cursor_ = 0;
...@@ -104,7 +106,10 @@ struct GraphTraits { ...@@ -104,7 +106,10 @@ struct GraphTraits {
static iterator_range<NodesTSIterator> TS(const Graph &g) { static iterator_range<NodesTSIterator> TS(const Graph &g) {
auto start_points = ExtractStartPoints(g); auto start_points = ExtractStartPoints(g);
PADDLE_ENFORCE(!start_points.empty()); PADDLE_ENFORCE_EQ(
start_points.empty(), false,
platform::errors::InvalidArgument(
"Start points of topological sorting should not be empty!"));
NodesTSIterator x(start_points); NodesTSIterator x(start_points);
return iterator_range<NodesTSIterator>(NodesTSIterator(start_points), return iterator_range<NodesTSIterator>(NodesTSIterator(start_points),
NodesTSIterator()); NodesTSIterator());
......
...@@ -42,7 +42,10 @@ void GraphVizPass::ApplyImpl(ir::Graph* graph) const { ...@@ -42,7 +42,10 @@ void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
const std::string& graph_viz_path = Get<std::string>(kGraphvizPath); const std::string& graph_viz_path = Get<std::string>(kGraphvizPath);
VLOG(3) << "draw IR graph viz to " << graph_viz_path; VLOG(3) << "draw IR graph viz to " << graph_viz_path;
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path)); std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE_EQ(
fout->good(), true,
platform::errors::Unavailable(
"Can not open file %s for printing the graph.", graph_viz_path));
std::ostream& sout = *fout; std::ostream& sout = *fout;
std::unordered_map<const ir::Node*, std::string> node2dot; std::unordered_map<const ir::Node*, std::string> node2dot;
......
...@@ -64,7 +64,11 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const { ...@@ -64,7 +64,11 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const {
for (auto& parameter : *pre_op_desc->Proto()->mutable_outputs()) { for (auto& parameter : *pre_op_desc->Proto()->mutable_outputs()) {
auto* arguments = parameter.mutable_arguments(); auto* arguments = parameter.mutable_arguments();
auto it = std::find(arguments->begin(), arguments->end(), scale_in_name); auto it = std::find(arguments->begin(), arguments->end(), scale_in_name);
PADDLE_ENFORCE(it != arguments->end()); PADDLE_ENFORCE_NE(
it, arguments->end(),
platform::errors::NotFound(
"Can not find input variable(%s) from scale op(%s).",
scale_in_name, pre_op_desc->Type()));
*it = scale_out_name; *it = scale_out_name;
} }
......
...@@ -33,7 +33,8 @@ const char kSumGradOpName[] = "sum"; ...@@ -33,7 +33,8 @@ const char kSumGradOpName[] = "sum";
const char kOptimizerType[] = "sgd"; const char kOptimizerType[] = "sgd";
void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const { void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
// We could collect all weights' name from SGD, where // We could collect all weights' name from SGD, where
// W1 <- SGD(W0, Grad0) // W1 <- SGD(W0, Grad0)
...@@ -41,7 +42,10 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -41,7 +42,10 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
if (IsOpNamed(node, kOptimizerType)) { if (IsOpNamed(node, kOptimizerType)) {
auto& param_out_vars = node->Op()->Output("ParamOut"); auto& param_out_vars = node->Op()->Output("ParamOut");
PADDLE_ENFORCE(param_out_vars.size() == 1u); PADDLE_ENFORCE_EQ(
param_out_vars.size(), 1u,
platform::errors::InvalidArgument(
"In op(%s), find output(ParamOut) failed.", node->Name()));
weight_var_set.insert(param_out_vars[0]); weight_var_set.insert(param_out_vars[0]);
} }
} }
...@@ -95,12 +99,19 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -95,12 +99,19 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Found forward_op " << forward_op->Name(); VLOG(3) << "Found forward_op " << forward_op->Name();
PADDLE_ENFORCE(forward_op); PADDLE_ENFORCE_NOT_NULL(
forward_op, platform::errors::NotFound(
"Can not find forward op for backword op(%s).",
backward_op->Name()));
Node* new_optimizer_node = CreateNewSGDNode( Node* new_optimizer_node = CreateNewSGDNode(
graph, forward_op, backward_op, node, opt_node); graph, forward_op, backward_op, node, opt_node);
PADDLE_ENFORCE(new_optimizer_node); PADDLE_ENFORCE_NOT_NULL(
new_optimizer_node,
platform::errors::InvalidArgument(
"Create new SGD node failed, backward op is %s.",
backward_op->Name()));
} }
} }
} }
...@@ -144,11 +155,21 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -144,11 +155,21 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
ir::Node* LockFreeOptimizePass::CreateNewSGDNode( ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
ir::Graph* graph, ir::Node* forward_node, ir::Node* backward_node, ir::Graph* graph, ir::Node* forward_node, ir::Node* backward_node,
ir::Node* grad_sum_node, ir::Node* optimize_node) const { ir::Node* grad_sum_node, ir::Node* optimize_node) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(graph,
PADDLE_ENFORCE(forward_node); platform::errors::InvalidArgument(
PADDLE_ENFORCE(backward_node); "Input argument graph cannot be nullptr."));
PADDLE_ENFORCE(grad_sum_node); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE(optimize_node); forward_node, platform::errors::InvalidArgument(
"Input argument forward_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
backward_node, platform::errors::InvalidArgument(
"Input argument backward_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
grad_sum_node, platform::errors::InvalidArgument(
"Input argument grad_sum_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
optimize_node, platform::errors::InvalidArgument(
"Input argument optimize_node cannot be nullptr."));
// find the grad var node between the grad sum node and backward_node // find the grad var node between the grad sum node and backward_node
std::vector<ir::Node*> grad_vars = std::vector<ir::Node*> grad_vars =
...@@ -159,7 +180,8 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode( ...@@ -159,7 +180,8 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
grad_node = node; grad_node = node;
} }
} }
PADDLE_ENFORCE(grad_node); PADDLE_ENFORCE_NOT_NULL(grad_node, platform::errors::NotFound(
"Can not find control dep variable."));
// create a new SGD node // create a new SGD node
OpDesc* old_desc = optimize_node->Op(); OpDesc* old_desc = optimize_node->Op();
...@@ -212,8 +234,14 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode( ...@@ -212,8 +234,14 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
} }
// SGD must have only one param and LR in // SGD must have only one param and LR in
PADDLE_ENFORCE(old_desc->Input("LearningRate").size() == 1u); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE(old_desc->Input("Param").size() == 1u); old_desc->Input("LearningRate").size(), 1u,
platform::errors::InvalidArgument(
"In op(%s), find input(LearningRate) failed.", old_desc->Type()));
PADDLE_ENFORCE_EQ(
old_desc->Input("Param").size(), 1u,
platform::errors::InvalidArgument("In op(%s), find input(Param) failed.",
old_desc->Type()));
// LR and weight nodes should be copied // LR and weight nodes should be copied
for (Node* upstream_node : optimize_node->inputs) { for (Node* upstream_node : optimize_node->inputs) {
...@@ -245,9 +273,17 @@ std::vector<ir::Node*> LockFreeOptimizePass::FindConnectedNode( ...@@ -245,9 +273,17 @@ std::vector<ir::Node*> LockFreeOptimizePass::FindConnectedNode(
void LockFreeOptimizePass::ReplaceUpstreamNode( void LockFreeOptimizePass::ReplaceUpstreamNode(
ir::Node* upstream_node, ir::Node* old_optimizer_node, ir::Node* upstream_node, ir::Node* old_optimizer_node,
ir::Node* new_optimizer_node) const { ir::Node* new_optimizer_node) const {
PADDLE_ENFORCE(upstream_node); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE(old_optimizer_node); upstream_node, platform::errors::InvalidArgument(
PADDLE_ENFORCE(new_optimizer_node); "Input argument upstream_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
old_optimizer_node,
platform::errors::InvalidArgument(
"Input argument old_optimizer_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
new_optimizer_node,
platform::errors::InvalidArgument(
"Input argument new_optimizer_node cannot be nullptr."));
// Remove the old_optimizer_node from upstream_node's outputs vector // Remove the old_optimizer_node from upstream_node's outputs vector
auto& output_node_vec = upstream_node->outputs; auto& output_node_vec = upstream_node->outputs;
...@@ -268,8 +304,14 @@ void LockFreeOptimizePass::ReplaceUpstreamNode( ...@@ -268,8 +304,14 @@ void LockFreeOptimizePass::ReplaceUpstreamNode(
void LockFreeOptimizePass::ReplaceAllDownstreamNode( void LockFreeOptimizePass::ReplaceAllDownstreamNode(
ir::Node* old_optimizer_node, ir::Node* new_optimizer_node) const { ir::Node* old_optimizer_node, ir::Node* new_optimizer_node) const {
PADDLE_ENFORCE(old_optimizer_node); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE(new_optimizer_node); old_optimizer_node,
platform::errors::InvalidArgument(
"Input argument old_optimizer_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
new_optimizer_node,
platform::errors::InvalidArgument(
"Input argument new_optimizer_node cannot be nullptr."));
for (ir::Node* downstream_node : old_optimizer_node->outputs) { for (ir::Node* downstream_node : old_optimizer_node->outputs) {
// Remove the old_optimizer_node from downstream_node's inputs vector // Remove the old_optimizer_node from downstream_node's inputs vector
...@@ -292,8 +334,12 @@ void LockFreeOptimizePass::ReplaceAllDownstreamNode( ...@@ -292,8 +334,12 @@ void LockFreeOptimizePass::ReplaceAllDownstreamNode(
ir::Node* LockFreeOptimizePass::FindForwardOpViaBackwardOp( ir::Node* LockFreeOptimizePass::FindForwardOpViaBackwardOp(
ir::Graph* graph, ir::Node* backward_node) const { ir::Graph* graph, ir::Node* backward_node) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(graph,
PADDLE_ENFORCE(backward_node); platform::errors::InvalidArgument(
"Input argument graph cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
backward_node, platform::errors::InvalidArgument(
"Input argument backward_node cannot be nullptr."));
// strip the suffix _grad of backward_node's name // strip the suffix _grad of backward_node's name
std::string forward_op_name = backward_node->Name(); std::string forward_op_name = backward_node->Name();
......
...@@ -87,34 +87,46 @@ class LockFreeOptimizePass : public Pass { ...@@ -87,34 +87,46 @@ class LockFreeOptimizePass : public Pass {
ir::Node* downstream_node) const; ir::Node* downstream_node) const;
inline bool IsOpNamed(ir::Node* node, const std::string& name) const { inline bool IsOpNamed(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node); PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return node->NodeType() == Node::Type::kOperation && node->Name() == name; return node->NodeType() == Node::Type::kOperation && node->Name() == name;
} }
inline bool IsVarNamed(ir::Node* node, const std::string& name) const { inline bool IsVarNamed(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node); PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return node->NodeType() == Node::Type::kVariable && node->Name() == name; return node->NodeType() == Node::Type::kVariable && node->Name() == name;
} }
inline bool IsVarNameEndsWith(ir::Node* node, const std::string& name) const { inline bool IsVarNameEndsWith(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node); PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return node->NodeType() == Node::Type::kVariable && return node->NodeType() == Node::Type::kVariable &&
boost::algorithm::ends_with(node->Name(), name); boost::algorithm::ends_with(node->Name(), name);
} }
inline bool IsVarNameContains(ir::Node* node, const std::string& name) const { inline bool IsVarNameContains(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node); PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return node->NodeType() == Node::Type::kVariable && return node->NodeType() == Node::Type::kVariable &&
node->Name().find(name) != std::string::npos; node->Name().find(name) != std::string::npos;
} }
inline bool IsControlDepFrom(ir::Node* ctrl_dep_node, ir::Node* node) const { inline bool IsControlDepFrom(ir::Node* ctrl_dep_node, ir::Node* node) const {
PADDLE_ENFORCE(ctrl_dep_node); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE(node); ctrl_dep_node, platform::errors::InvalidArgument(
"Input argument ctrl_dep_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return IsControlDepVar(*ctrl_dep_node) && return IsControlDepVar(*ctrl_dep_node) &&
ctrl_dep_node->inputs.size() >= 1u && ctrl_dep_node->inputs.size() >= 1u &&
......
...@@ -116,7 +116,10 @@ std::vector<OpHandleBase *> BufferSharedCrossOpMemoryReusePass::SortOp( ...@@ -116,7 +116,10 @@ std::vector<OpHandleBase *> BufferSharedCrossOpMemoryReusePass::SortOp(
graph_view.BreadthFirstVisit( graph_view.BreadthFirstVisit(
[&](OpHandleBase *cur_op) { sorted_ops.emplace_back(cur_op); }); [&](OpHandleBase *cur_op) { sorted_ops.emplace_back(cur_op); });
PADDLE_ENFORCE_EQ(sorted_ops.size(), graph_view.OpNumber(), PADDLE_ENFORCE_EQ(sorted_ops.size(), graph_view.OpNumber(),
"There are unvisited ops"); platform::errors::InvalidArgument(
"Sorted ops size(%d) not equal to graph op size(%d). "
"There are unvisited ops.",
sorted_ops.size(), graph_view.OpNumber()));
return sorted_ops; return sorted_ops;
} }
...@@ -181,7 +184,9 @@ void BufferSharedCrossOpMemoryReusePass::RunOnScopeIdx(size_t idx) const { ...@@ -181,7 +184,9 @@ void BufferSharedCrossOpMemoryReusePass::RunOnScopeIdx(size_t idx) const {
auto *out_node = *(out_nodes.begin()); auto *out_node = *(out_nodes.begin());
auto *out_var = auto *out_var =
dynamic_cast<VarHandle *>(&(out_node->Wrapper<VarHandleBase>())); dynamic_cast<VarHandle *>(&(out_node->Wrapper<VarHandleBase>()));
PADDLE_ENFORCE_NOT_NULL(out_var); PADDLE_ENFORCE_NOT_NULL(
out_var, platform::errors::NotFound(
"Can not find a valid Var Node for Var %s.", out_arg));
// If out_arg is not reusable, skip it // If out_arg is not reusable, skip it
if (!IsOutVarReusable(*out_var)) { if (!IsOutVarReusable(*out_var)) {
...@@ -269,7 +274,8 @@ size_t BufferSharedCrossOpMemoryReusePass::ResolveDependencyBetween( ...@@ -269,7 +274,8 @@ size_t BufferSharedCrossOpMemoryReusePass::ResolveDependencyBetween(
auto op_dep = GetOpDep(prev_op, op); auto op_dep = GetOpDep(prev_op, op);
if (op_dep == NodeDependency::kBefore) continue; if (op_dep == NodeDependency::kBefore) continue;
PADDLE_ENFORCE_EQ(op_dep, NodeDependency::kNoDep, PADDLE_ENFORCE_EQ(op_dep, NodeDependency::kNoDep,
"The graph has circle, this may be a bug"); platform::errors::InvalidArgument(
"The graph has circle, this may be a bug."));
auto iter = auto iter =
std::find_if(prev_op->Outputs().begin(), prev_op->Outputs().end(), std::find_if(prev_op->Outputs().begin(), prev_op->Outputs().end(),
...@@ -316,9 +322,13 @@ size_t BufferSharedCrossOpMemoryReusePass::ResolveDependencyBetween( ...@@ -316,9 +322,13 @@ size_t BufferSharedCrossOpMemoryReusePass::ResolveDependencyBetween(
} }
void BufferSharedCrossOpMemoryReusePass::BuildOpDependencyMap() const { void BufferSharedCrossOpMemoryReusePass::BuildOpDependencyMap() const {
PADDLE_ENFORCE(ops_.empty(), "ops_ must be initialized here"); PADDLE_ENFORCE_EQ(ops_.empty(), true, platform::errors::InvalidArgument(
PADDLE_ENFORCE(op_to_idx_.empty(), "op_to_idx_ must be initialized here"); "Ops must be initialized here."));
PADDLE_ENFORCE(deps_.empty(), "deps_ must be initialized here"); PADDLE_ENFORCE_EQ(
op_to_idx_.empty(), true,
platform::errors::InvalidArgument("Op to idx must be initialized here."));
PADDLE_ENFORCE_EQ(deps_.empty(), true, platform::errors::InvalidArgument(
"Deps must be initialized here."));
// Toposort ops // Toposort ops
OpGraphView graph_view(ir::FilterByNodeWrapper<OpHandleBase>(*graph_)); OpGraphView graph_view(ir::FilterByNodeWrapper<OpHandleBase>(*graph_));
...@@ -344,7 +354,10 @@ void BufferSharedCrossOpMemoryReusePass::BuildOpDependencyMap() const { ...@@ -344,7 +354,10 @@ void BufferSharedCrossOpMemoryReusePass::BuildOpDependencyMap() const {
prev_preceding_ops.end()); prev_preceding_ops.end());
} }
}); });
PADDLE_ENFORCE_EQ(preceding_ops.size(), op_num); PADDLE_ENFORCE_EQ(preceding_ops.size(), op_num,
platform::errors::InvalidArgument(
"Preceding ops size(%d) must equal to op num(%d).",
preceding_ops.size(), op_num));
// Find out ComputationOpHandles only // Find out ComputationOpHandles only
ops_.resize(scope_num); ops_.resize(scope_num);
...@@ -384,28 +397,43 @@ void BufferSharedCrossOpMemoryReusePass::BuildOpDependencyMap() const { ...@@ -384,28 +397,43 @@ void BufferSharedCrossOpMemoryReusePass::BuildOpDependencyMap() const {
size_t BufferSharedCrossOpMemoryReusePass::OpIndex( size_t BufferSharedCrossOpMemoryReusePass::OpIndex(
const ComputationOpHandle *op) const { const ComputationOpHandle *op) const {
auto iter = op_to_idx_[op->GetScopeIdx()].find(op); auto iter = op_to_idx_[op->GetScopeIdx()].find(op);
PADDLE_ENFORCE(iter != op_to_idx_[op->GetScopeIdx()].end()); PADDLE_ENFORCE_NE(iter, op_to_idx_[op->GetScopeIdx()].end(),
platform::errors::NotFound(
"Can not find op(%s) in op_to_idx_.", op->Name()));
return iter->second; return iter->second;
} }
NodeDependency BufferSharedCrossOpMemoryReusePass::GetOpDep( NodeDependency BufferSharedCrossOpMemoryReusePass::GetOpDep(
const ComputationOpHandle *op1, const ComputationOpHandle *op2) const { const ComputationOpHandle *op1, const ComputationOpHandle *op2) const {
PADDLE_ENFORCE_EQ(op1->GetScopeIdx(), op2->GetScopeIdx()); PADDLE_ENFORCE_EQ(op1->GetScopeIdx(), op2->GetScopeIdx(),
platform::errors::InvalidArgument(
"Op(%s) and op(%s) must in the same scope.",
op1->Name(), op2->Name()));
return deps_[op1->GetScopeIdx()][OpIndex(op1)][OpIndex(op2)]; return deps_[op1->GetScopeIdx()][OpIndex(op1)][OpIndex(op2)];
} }
void BufferSharedCrossOpMemoryReusePass::SetOpDep( void BufferSharedCrossOpMemoryReusePass::SetOpDep(
const ComputationOpHandle *op1, const ComputationOpHandle *op2, const ComputationOpHandle *op1, const ComputationOpHandle *op2,
NodeDependency dep) const { NodeDependency dep) const {
PADDLE_ENFORCE_EQ(op1->GetScopeIdx(), op2->GetScopeIdx()); PADDLE_ENFORCE_EQ(op1->GetScopeIdx(), op2->GetScopeIdx(),
platform::errors::InvalidArgument(
"Op(%s) and op(%s) must in the same scope.",
op1->Name(), op2->Name()));
if (op1 == op2) { if (op1 == op2) {
PADDLE_ENFORCE(dep == NodeDependency::kSame); PADDLE_ENFORCE_EQ(
dep, NodeDependency::kSame,
platform::errors::InvalidArgument(
"Set Same Op(%s) Dep, dep must be kSame type.", op1->Name()));
auto idx = OpIndex(op1); auto idx = OpIndex(op1);
deps_[op1->GetScopeIdx()][idx][idx] = NodeDependency::kSame; deps_[op1->GetScopeIdx()][idx][idx] = NodeDependency::kSame;
} else { } else {
auto idx1 = OpIndex(op1); auto idx1 = OpIndex(op1);
auto idx2 = OpIndex(op2); auto idx2 = OpIndex(op2);
PADDLE_ENFORCE(dep != NodeDependency::kSame && idx1 != idx2); PADDLE_ENFORCE_EQ((dep != NodeDependency::kSame && idx1 != idx2), true,
platform::errors::InvalidArgument(
"Op(%s) and Op(%s) should not have same "
"index(%d), and dep should not kSame type.",
op1->Name(), op2->Name(), idx1));
deps_[op1->GetScopeIdx()][idx1][idx2] = dep; deps_[op1->GetScopeIdx()][idx1][idx2] = dep;
deps_[op1->GetScopeIdx()][idx2][idx1] = ReverseNodeDependency(dep); deps_[op1->GetScopeIdx()][idx2][idx1] = ReverseNodeDependency(dep);
} }
......
...@@ -57,7 +57,9 @@ void BufferSharedInplaceOpPass::Run(Graph *graph) const { ...@@ -57,7 +57,9 @@ void BufferSharedInplaceOpPass::Run(Graph *graph) const {
auto *op = *(pair.second.ops().begin()); auto *op = *(pair.second.ops().begin());
const std::string &op_type = op->GetOp()->Type(); const std::string &op_type = op->GetOp()->Type();
const framework::OpDesc *op_desc = op->Node()->Op(); const framework::OpDesc *op_desc = op->Node()->Op();
PADDLE_ENFORCE_NOT_NULL(op_desc); PADDLE_ENFORCE_NOT_NULL(
op_desc, platform::errors::NotFound("Op(%s) can not find opdesc.",
op->Name()));
auto &infer_inplace = OpInfoMap::Instance().Get(op_type).infer_inplace_; auto &infer_inplace = OpInfoMap::Instance().Get(op_type).infer_inplace_;
if (!infer_inplace) { if (!infer_inplace) {
......
...@@ -58,8 +58,12 @@ static int64_t GetMemorySize( ...@@ -58,8 +58,12 @@ static int64_t GetMemorySize(
&vars, &vars,
const std::string &var_name) { const std::string &var_name) {
auto *var_desc = TryGetLatestVarDesc(vars.at(var_name)); auto *var_desc = TryGetLatestVarDesc(vars.at(var_name));
PADDLE_ENFORCE_NOT_NULL(var_desc); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE(IsLoDTensor(var_desc)); var_desc,
platform::errors::NotFound("Var(%s) can not find VarDesc.", var_name));
PADDLE_ENFORCE_EQ(IsLoDTensor(var_desc), true,
platform::errors::InvalidArgument(
"Var(%s) must be LoDTensor.", var_name));
auto dims = var_desc->GetShape(); auto dims = var_desc->GetShape();
return SizeOfType(var_desc->GetDataType()) * return SizeOfType(var_desc->GetDataType()) *
std::accumulate(dims.begin(), dims.end(), static_cast<int64_t>(1), std::accumulate(dims.begin(), dims.end(), static_cast<int64_t>(1),
......
...@@ -42,8 +42,10 @@ class MemOptVarInfo { ...@@ -42,8 +42,10 @@ class MemOptVarInfo {
} }
void SetRefCnt(size_t ref_cnt) { void SetRefCnt(size_t ref_cnt) {
PADDLE_ENFORCE_GE(ref_cnt, 1, PADDLE_ENFORCE_GE(
"Reference count must be larger than or equal to 1"); ref_cnt, 1,
platform::errors::InvalidArgument(
"Reference count(%d) must be larger than or equal to 1.", ref_cnt));
ref_cnt_ = ref_cnt; ref_cnt_ = ref_cnt;
runtime_ref_cnt_ = ref_cnt; runtime_ref_cnt_ = ref_cnt;
} }
......
...@@ -66,7 +66,11 @@ bool MemoryReusePass::TryReuseVar(details::VarHandle *in_var, ...@@ -66,7 +66,11 @@ bool MemoryReusePass::TryReuseVar(details::VarHandle *in_var,
details::VarHandle *out_var) const { details::VarHandle *out_var) const {
auto *op = auto *op =
dynamic_cast<details::ComputationOpHandle *>(out_var->GeneratedOp()); dynamic_cast<details::ComputationOpHandle *>(out_var->GeneratedOp());
PADDLE_ENFORCE_NOT_NULL(op); PADDLE_ENFORCE_NOT_NULL(
op,
platform::errors::InvalidArgument(
"Var(%s) have no GeneratedOp, or it's op is not ComputationOpHandle.",
out_var->Name()));
if (IsVarPairReusable(*in_var, *out_var)) { if (IsVarPairReusable(*in_var, *out_var)) {
AddReuseVar(op, in_var, out_var); AddReuseVar(op, in_var, out_var);
return true; return true;
...@@ -91,10 +95,13 @@ VarDesc *MemoryReusePass::GetVarDesc(const details::VarHandle &var) const { ...@@ -91,10 +95,13 @@ VarDesc *MemoryReusePass::GetVarDesc(const details::VarHandle &var) const {
size_t scope_idx = var.scope_idx(); size_t scope_idx = var.scope_idx();
auto iter = var_descs_[scope_idx].find(var_name); auto iter = var_descs_[scope_idx].find(var_name);
if (iter == var_descs_[scope_idx].end()) { if (iter == var_descs_[scope_idx].end()) {
PADDLE_ENFORCE((*all_vars_)[scope_idx].count(var_name), PADDLE_ENFORCE_NE(
"Variable %s not found", var_name); (*all_vars_)[scope_idx].count(var_name), 0,
platform::errors::NotFound("Variable %s not found.", var_name));
auto *desc = TryGetLatestVarDesc((*all_vars_)[scope_idx].at(var_name)); auto *desc = TryGetLatestVarDesc((*all_vars_)[scope_idx].at(var_name));
PADDLE_ENFORCE_NOT_NULL(desc); PADDLE_ENFORCE_NOT_NULL(
desc,
platform::errors::NotFound("Var(%s) can not find VarDesc.", var_name));
var_descs_[scope_idx].emplace(var_name, desc); var_descs_[scope_idx].emplace(var_name, desc);
return desc; return desc;
} else { } else {
...@@ -119,7 +126,9 @@ void MemoryReusePass::CollectShareTensorBufferOpHandles() const { ...@@ -119,7 +126,9 @@ void MemoryReusePass::CollectShareTensorBufferOpHandles() const {
if (share_buffer_op != nullptr) { if (share_buffer_op != nullptr) {
auto *compute_op = auto *compute_op =
details::GetUniquePendingComputationOpHandle(share_buffer_op); details::GetUniquePendingComputationOpHandle(share_buffer_op);
PADDLE_ENFORCE(ops_.count(compute_op) == 0); PADDLE_ENFORCE_EQ(
ops_.count(compute_op), 0,
platform::errors::AlreadyExists("Compute op already exists."));
ops_.emplace(compute_op, share_buffer_op); ops_.emplace(compute_op, share_buffer_op);
} }
} }
...@@ -227,8 +236,11 @@ bool MemoryReusePass::IsInVarReusable(const details::VarHandle &in_var) const { ...@@ -227,8 +236,11 @@ bool MemoryReusePass::IsInVarReusable(const details::VarHandle &in_var) const {
*/ */
bool MemoryReusePass::IsOutVarReusable( bool MemoryReusePass::IsOutVarReusable(
const details::VarHandle &out_var) const { const details::VarHandle &out_var) const {
PADDLE_ENFORCE_NOT_NULL(dynamic_cast<const details::ComputationOpHandle *>( PADDLE_ENFORCE_NOT_NULL(
out_var.GeneratedOp())); dynamic_cast<const details::ComputationOpHandle *>(out_var.GeneratedOp()),
platform::errors::InvalidArgument(
"Var(%s) have no GeneratedOp, or it's op is not ComputationOpHandle.",
out_var.Name()));
const auto out_name = out_var.Name(); const auto out_name = out_var.Name();
if (out_name == kEmptyVarName) { if (out_name == kEmptyVarName) {
return false; return false;
...@@ -236,9 +248,10 @@ bool MemoryReusePass::IsOutVarReusable( ...@@ -236,9 +248,10 @@ bool MemoryReusePass::IsOutVarReusable(
// out_var must be the first version!!! // out_var must be the first version!!!
auto out_var_iter = (*all_vars_)[out_var.scope_idx()].find(out_name); auto out_var_iter = (*all_vars_)[out_var.scope_idx()].find(out_name);
PADDLE_ENFORCE(out_var_iter != (*all_vars_)[out_var.scope_idx()].end() && PADDLE_ENFORCE_EQ(
!out_var_iter->second.empty(), (out_var_iter != (*all_vars_)[out_var.scope_idx()].end() &&
"Cannot find variable %s", out_name); !out_var_iter->second.empty()),
true, platform::errors::NotFound("Cannot find variable %s.", out_name));
if (out_var_iter->second[0] != &out_var) { if (out_var_iter->second[0] != &out_var) {
return false; return false;
...@@ -282,7 +295,11 @@ bool MemoryReusePass::IsVarPairReusable( ...@@ -282,7 +295,11 @@ bool MemoryReusePass::IsVarPairReusable(
const details::VarHandle &in_var, const details::VarHandle &out_var) const { const details::VarHandle &in_var, const details::VarHandle &out_var) const {
auto *op = auto *op =
dynamic_cast<const details::ComputationOpHandle *>(out_var.GeneratedOp()); dynamic_cast<const details::ComputationOpHandle *>(out_var.GeneratedOp());
PADDLE_ENFORCE_NOT_NULL(op); PADDLE_ENFORCE_NOT_NULL(
op,
platform::errors::InvalidArgument(
"Var(%s) have no GeneratedOp, or it's op is not ComputationOpHandle.",
out_var.Name()));
const auto in_name = in_var.Name(); const auto in_name = in_var.Name();
if (in_name == out_var.Name()) { if (in_name == out_var.Name()) {
...@@ -308,8 +325,10 @@ bool MemoryReusePass::IsVarPairReusable( ...@@ -308,8 +325,10 @@ bool MemoryReusePass::IsVarPairReusable(
void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op, void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op,
details::VarHandle *in_var, details::VarHandle *in_var,
details::VarHandle *out_var) const { details::VarHandle *out_var) const {
PADDLE_ENFORCE((*var_infos_)[op->GetScopeIdx()].count(in_var->Name()) > 0, PADDLE_ENFORCE_GT(
"%s does not in mem-opt var infos", in_var->Name()); (*var_infos_)[op->GetScopeIdx()].count(in_var->Name()), 0,
platform::errors::NotFound("Var(%s) does not in mem opt var infos.",
in_var->Name()));
if (ops_.count(op) == 0) { if (ops_.count(op) == 0) {
InsertShareTensorBufferOpHandleToGraph(op); InsertShareTensorBufferOpHandleToGraph(op);
...@@ -349,7 +368,10 @@ void MemoryReusePass::UpdateLastLiveOpOfVar(details::ComputationOpHandle *op, ...@@ -349,7 +368,10 @@ void MemoryReusePass::UpdateLastLiveOpOfVar(details::ComputationOpHandle *op,
if (out_var_op_iter == (*last_live_ops_of_vars_)[scope_idx].end()) { if (out_var_op_iter == (*last_live_ops_of_vars_)[scope_idx].end()) {
last_live_op_of_in_var = op; last_live_op_of_in_var = op;
} else { } else {
PADDLE_ENFORCE(!out_var_op_iter->second.ops().empty()); PADDLE_ENFORCE_EQ(
out_var_op_iter->second.ops().empty(), false,
platform::errors::InvalidArgument(
"Var(%s)'s last live op should not empty.", out_var->Name()));
last_live_op_of_in_var = *(out_var_op_iter->second.ops().begin()); last_live_op_of_in_var = *(out_var_op_iter->second.ops().begin());
} }
...@@ -359,8 +381,9 @@ void MemoryReusePass::UpdateLastLiveOpOfVar(details::ComputationOpHandle *op, ...@@ -359,8 +381,9 @@ void MemoryReusePass::UpdateLastLiveOpOfVar(details::ComputationOpHandle *op,
last_live_ops_of_in_var->insert(last_live_op_of_in_var); last_live_ops_of_in_var->insert(last_live_op_of_in_var);
auto in_var_info_iter = (*var_infos_)[scope_idx].find(in_var->Name()); auto in_var_info_iter = (*var_infos_)[scope_idx].find(in_var->Name());
PADDLE_ENFORCE(in_var_info_iter != (*var_infos_)[scope_idx].end(), PADDLE_ENFORCE_NE(
"Cannot find variable %s", in_var->Name()); in_var_info_iter, (*var_infos_)[scope_idx].end(),
platform::errors::NotFound("Cannot find variable %s.", in_var->Name()));
in_var_info_iter->second->SetRefCnt(1); in_var_info_iter->second->SetRefCnt(1);
} }
......
...@@ -39,7 +39,7 @@ void OpGraphView::Build(const std::vector<details::OpHandleBase *> &ops) { ...@@ -39,7 +39,7 @@ void OpGraphView::Build(const std::vector<details::OpHandleBase *> &ops) {
} }
PADDLE_ENFORCE( PADDLE_ENFORCE(
preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(), preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
"There are duplicate ops in graph."); platform::errors::InvalidArgument("There are duplicate ops in graph."));
} }
std::unordered_set<details::OpHandleBase *> OpGraphView::AllOps() const { std::unordered_set<details::OpHandleBase *> OpGraphView::AllOps() const {
...@@ -56,8 +56,10 @@ bool OpGraphView::HasOp(details::OpHandleBase *op) const { ...@@ -56,8 +56,10 @@ bool OpGraphView::HasOp(details::OpHandleBase *op) const {
} }
void OpGraphView::EnforceHasOp(details::OpHandleBase *op) const { void OpGraphView::EnforceHasOp(details::OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView", PADDLE_ENFORCE_EQ(HasOp(op), true,
op == nullptr ? "nullptr" : op->DebugString()); platform::errors::NotFound(
"Cannot find op %s in OpGraphView.",
op == nullptr ? "nullptr" : op->DebugString()));
} }
const std::unordered_set<details::OpHandleBase *> &OpGraphView::PendingOps( const std::unordered_set<details::OpHandleBase *> &OpGraphView::PendingOps(
......
...@@ -127,9 +127,13 @@ void OpGraphView::BreadthFirstVisit(Callback &&callback) const { ...@@ -127,9 +127,13 @@ void OpGraphView::BreadthFirstVisit(Callback &&callback) const {
} }
} }
PADDLE_ENFORCE_EQ(num_calls, op_num, "There are unvisited ops"); PADDLE_ENFORCE_EQ(num_calls, op_num, platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(visited_ops.size(), op_num, "There are unvisited ops"); "There are unvisited ops."));
PADDLE_ENFORCE(op_deps.empty(), "There are unvisited ops"); PADDLE_ENFORCE_EQ(
visited_ops.size(), op_num,
platform::errors::InvalidArgument("There are unvisited ops."));
PADDLE_ENFORCE_EQ(op_deps.empty(), true, platform::errors::InvalidArgument(
"There are unvisited ops."));
} }
} // namespace ir } // namespace ir
......
...@@ -77,11 +77,15 @@ class ShrinkDepsOpFunctor { ...@@ -77,11 +77,15 @@ class ShrinkDepsOpFunctor {
const std::vector<details::OpHandleBase *> &ops) const { const std::vector<details::OpHandleBase *> &ops) const {
std::unordered_map<details::OpHandleBase *, size_t> op_to_idx; std::unordered_map<details::OpHandleBase *, size_t> op_to_idx;
for (size_t i = 0; i < ops.size(); ++i) { for (size_t i = 0; i < ops.size(); ++i) {
PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph"); PADDLE_ENFORCE_EQ(
graph_.HasOp(ops[i]), true,
platform::errors::InvalidArgument("Op does not exist in graph."));
op_to_idx[ops[i]] = i; op_to_idx[ops[i]] = i;
} }
PADDLE_ENFORCE(op_to_idx.size() == ops.size(), "Duplicate ops"); PADDLE_ENFORCE_EQ(
op_to_idx.size(), ops.size(),
platform::errors::InvalidArgument("Graph may have duplicate ops."));
std::vector<std::vector<RelationShip>> ret(ops.size()); std::vector<std::vector<RelationShip>> ret(ops.size());
for (auto &e : ret) { for (auto &e : ret) {
...@@ -247,9 +251,9 @@ ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx, ...@@ -247,9 +251,9 @@ ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx,
return {}; return {};
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(computation_ops.empty(), false,
computation_ops.empty(), false, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("Computation ops should not be empty")); "Computation ops should not be empty."));
// stage four. Try to shrink computation op if they depend on each other. // stage four. Try to shrink computation op if they depend on each other.
// Get the smallest set of the most ops. // Get the smallest set of the most ops.
...@@ -263,8 +267,9 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ...@@ -263,8 +267,9 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars); Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
PADDLE_ENFORCE(last_live_ops_of_vars.empty() && var_infos.empty(), PADDLE_ENFORCE(last_live_ops_of_vars.empty() && var_infos.empty(),
"Last Live Ops and Reference Counts of vars should be " platform::errors::InvalidArgument(
"initialized at here."); "Last live ops and reference counts of vars should be "
"initialized at here."));
const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars); const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
...@@ -304,11 +309,15 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ...@@ -304,11 +309,15 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &var_name = name_var_pair.first; auto &var_name = name_var_pair.first;
auto &var_handles = name_var_pair.second; auto &var_handles = name_var_pair.second;
PADDLE_ENFORCE_EQ(var_desc->Name(), var_name);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
var_handles.empty(), false, var_desc->Name(), var_name,
platform::errors::InvalidArgument("Variable %s not found", var_name)); platform::errors::InvalidArgument(
"A Var, it's VarName(%s) and DescName(%s) not same.", var_name,
var_desc->Name()));
PADDLE_ENFORCE_EQ(var_handles.empty(), false,
platform::errors::InvalidArgument(
"Variable %s not found.", var_name));
auto last_ver_var = var_handles.back(); auto last_ver_var = var_handles.back();
if (last_ver_var->Node()->IsCtrlVar()) { if (last_ver_var->Node()->IsCtrlVar()) {
...@@ -327,12 +336,13 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ...@@ -327,12 +336,13 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
continue; continue;
} }
PADDLE_ENFORCE_EQ(status, LastLiveOpSearchStatus::kSuccess,
platform::errors::InvalidArgument(
"Status(%d) must be success.", status));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
status, LastLiveOpSearchStatus::kSuccess, result.empty(), false,
platform::errors::InvalidArgument("status must be success")); platform::errors::NotFound("Last living ops of %s cannot be empty.",
PADDLE_ENFORCE_EQ(result.empty(), false, var_name));
platform::errors::NotFound(
"Last living ops of %s cannot be empty", var_name));
std::string last_live_ops_log_str; std::string last_live_ops_log_str;
for (auto &each_ret : result) { for (auto &each_ret : result) {
......
...@@ -22,7 +22,8 @@ namespace framework { ...@@ -22,7 +22,8 @@ namespace framework {
namespace ir { namespace ir {
void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const { void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph, "graph cannot be nullptr."); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("conv_activation_mkldnn_fuse", graph); FusePassBase::Init("conv_activation_mkldnn_fuse", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -75,7 +76,8 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -75,7 +76,8 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
GraphSafeRemoveNodes(graph, {activation, conv_out}); GraphSafeRemoveNodes(graph, {activation, conv_out});
PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL, PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL,
"subgraph has to contain conv_input node."); platform::errors::InvalidArgument(
"Subgraph has to contain conv input node."));
IR_NODE_LINK_TO(conv, activation_out); IR_NODE_LINK_TO(conv, activation_out);
found_conv_activation_count++; found_conv_activation_count++;
}; };
......
...@@ -26,7 +26,11 @@ namespace ir { ...@@ -26,7 +26,11 @@ namespace ir {
template <typename BinaryOperation> template <typename BinaryOperation>
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b, LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
BinaryOperation f) { BinaryOperation f) {
PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims()); PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims(),
platform::errors::InvalidArgument(
"Input two tensors must have same shape, but they are "
"different: %s, %s.",
vec_a.dims(), vec_b.dims()));
LoDTensor vec_y; LoDTensor vec_y;
vec_y.Resize(vec_a.dims()); vec_y.Resize(vec_a.dims());
const float* a = vec_a.data<float>(); const float* a = vec_a.data<float>();
...@@ -39,11 +43,13 @@ LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b, ...@@ -39,11 +43,13 @@ LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
} }
void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope(); auto* scope = param_scope();
PADDLE_ENFORCE(scope); PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* conv_input = auto* conv_input =
...@@ -68,7 +74,9 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -68,7 +74,9 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
// elementwise_add op // elementwise_add op
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);
PADDLE_ENFORCE(subgraph.count(conv_input)); PADDLE_ENFORCE_NE(
subgraph.count(conv_input), 0,
platform::errors::NotFound("Detector did not find conv input."));
// check if fuse can be done and if MKL-DNN should be used // check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *eltwise); FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
...@@ -86,10 +94,16 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -86,10 +94,16 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
if (has_bias && conv->Op()->Input("Bias").size() > 0) { if (has_bias && conv->Op()->Input("Bias").size() > 0) {
auto conv_bias_names = conv->Op()->Input("Bias"); auto conv_bias_names = conv->Op()->Input("Bias");
// add eltwise bias to existing conv bias // add eltwise bias to existing conv bias
PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1); PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1,
platform::errors::NotFound("Can not find var Bias."));
auto* conv_bias_var = scope->FindVar(conv_bias_names[0]); auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>(); auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims()); PADDLE_ENFORCE_EQ(
conv_bias_tensor->dims(), eltwise_bias_tensor->dims(),
platform::errors::InvalidArgument(
"Conv bias tensor and eltwise bias tensor "
"must have same shape, but they are different: %s, %s.",
conv_bias_tensor->dims(), eltwise_bias_tensor->dims()));
*conv_bias_tensor = tensor_apply_eltwise( *conv_bias_tensor = tensor_apply_eltwise(
*conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>()); *conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>());
......
...@@ -39,7 +39,10 @@ void ConvConcatReLUFusePass::FindConcatWithConvs( ...@@ -39,7 +39,10 @@ void ConvConcatReLUFusePass::FindConcatWithConvs(
for (auto node : concat_inputs) { for (auto node : concat_inputs) {
auto prev_op_node = node->inputs; auto prev_op_node = node->inputs;
PADDLE_ENFORCE_EQ(prev_op_node.size(), 1); PADDLE_ENFORCE_EQ(prev_op_node.size(), 1,
platform::errors::InvalidArgument(
"Node(%s) input size(%d) must be 1.", node->Name(),
prev_op_node.size()));
auto* conv_op = prev_op_node[0]; auto* conv_op = prev_op_node[0];
if (conv_op->Op()->Type() != "conv2d") return; if (conv_op->Op()->Type() != "conv2d") return;
...@@ -103,7 +106,8 @@ void ConvConcatReLUFusePass::FuseConvConcatReLU( ...@@ -103,7 +106,8 @@ void ConvConcatReLUFusePass::FuseConvConcatReLU(
} }
void ConvConcatReLUFusePass::ApplyImpl(ir::Graph* graph) const { void ConvConcatReLUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
std::unordered_map<const Node*, int> concat_with_convs_counter; std::unordered_map<const Node*, int> concat_with_convs_counter;
......
...@@ -68,10 +68,10 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input, ...@@ -68,10 +68,10 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input,
auto inputs = op->Op()->InputNames(); auto inputs = op->Op()->InputNames();
bool name_found = bool name_found =
std::find(inputs.begin(), inputs.end(), input_name) != inputs.end(); std::find(inputs.begin(), inputs.end(), input_name) != inputs.end();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(name_found, true,
name_found, true, platform::errors::InvalidArgument(
platform::errors::InvalidArgument("%s isn't the input of the %s operator", "Var(%s) isn't the input of the %s operator.",
input_name, op->Op()->Type())); input_name, op->Op()->Type()));
unsigned max = is_unsigned ? U8_MAX : S8_MAX; unsigned max = is_unsigned ? U8_MAX : S8_MAX;
float scale = scale_to_one * max; float scale = scale_to_one * max;
...@@ -110,8 +110,14 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name, ...@@ -110,8 +110,14 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name,
std::string scale_attr_name) const { std::string scale_attr_name) const {
auto inputs = op->inputs; auto inputs = op->inputs;
auto output = op->outputs[0]; auto output = op->outputs[0];
PADDLE_ENFORCE_GE(inputs.size(), 1); PADDLE_ENFORCE_GE(inputs.size(), 1,
PADDLE_ENFORCE_EQ(op->outputs.size(), 1); platform::errors::InvalidArgument(
"OP(%s)'s inputs(%d) must be equal or greater than 1.",
op->Name(), inputs.size()));
PADDLE_ENFORCE_EQ(op->outputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s outputs(%d) must be equal to 1.", op->Name(),
op->outputs.size()));
// create a quantize op desc prototype // create a quantize op desc prototype
OpDesc q_desc; OpDesc q_desc;
...@@ -159,8 +165,8 @@ void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output, ...@@ -159,8 +165,8 @@ void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output,
std::find(outputs.begin(), outputs.end(), output_name) != outputs.end(); std::find(outputs.begin(), outputs.end(), output_name) != outputs.end();
PADDLE_ENFORCE_EQ(name_found, true, PADDLE_ENFORCE_EQ(name_found, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"%s isn't the output of the %s operator", output_name, "Var(%s) isn't the output of the %s operator.",
op->Op()->Type())); output_name, op->Op()->Type()));
unsigned max = is_unsigned ? U8_MAX : S8_MAX; unsigned max = is_unsigned ? U8_MAX : S8_MAX;
float scale = scale_to_one * max; float scale = scale_to_one * max;
...@@ -682,10 +688,12 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { ...@@ -682,10 +688,12 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
bool is_x_unsigned{false}, is_y_unsigned{false}; bool is_x_unsigned{false}, is_y_unsigned{false};
auto input_x_scale = GetScaleValueForNode(matmul_in_x, &is_x_unsigned); auto input_x_scale = GetScaleValueForNode(matmul_in_x, &is_x_unsigned);
auto input_y_scale = GetScaleValueForNode(matmul_in_y, &is_y_unsigned); auto input_y_scale = GetScaleValueForNode(matmul_in_y, &is_y_unsigned);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(is_x_unsigned, is_y_unsigned,
is_x_unsigned, is_y_unsigned, platform::errors::InvalidArgument(
platform::errors::InvalidArgument( "Matmul inputs should have the same "
"Matmul inputs should have the same value of is_unsigned")); "attribute of signed/unsigned, but they "
"are different: x(%d), y(%d).",
is_x_unsigned, is_y_unsigned));
QuantizeInput(g, matmul_op, matmul_in_x, "X", input_x_scale, is_x_unsigned, QuantizeInput(g, matmul_op, matmul_in_x, "X", input_x_scale, is_x_unsigned,
"Scale_x"); "Scale_x");
QuantizeInput(g, matmul_op, matmul_in_y, "Y", input_y_scale, is_y_unsigned, QuantizeInput(g, matmul_op, matmul_in_y, "Y", input_y_scale, is_y_unsigned,
...@@ -785,10 +793,12 @@ void CPUQuantizePass::QuantizeElementwiseAdd(Graph* graph) const { ...@@ -785,10 +793,12 @@ void CPUQuantizePass::QuantizeElementwiseAdd(Graph* graph) const {
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Quantizing the graph."; VLOG(3) << "Quantizing the graph.";
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
PADDLE_ENFORCE(param_scope()); PADDLE_ENFORCE_NOT_NULL(param_scope(), platform::errors::InvalidArgument(
"Scope cannot be nullptr."));
QuantizeConv(graph, false /* with_residual_data */); QuantizeConv(graph, false /* with_residual_data */);
QuantizeConv(graph, true /* with_residual_data */); QuantizeConv(graph, true /* with_residual_data */);
......
...@@ -75,7 +75,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash( ...@@ -75,7 +75,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
BOOST_GET_CONST(float, quant_op->Op()->GetAttr("Scale")); BOOST_GET_CONST(float, quant_op->Op()->GetAttr("Scale"));
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
nodes_keep_counter->find(dequant_out), nodes_keep_counter->end(), nodes_keep_counter->find(dequant_out), nodes_keep_counter->end(),
platform::errors::NotFound("The dequant output node is not found")); platform::errors::NotFound("The dequant output node is not found."));
// check if dequantize op should be kept or removed, decrease the counter // check if dequantize op should be kept or removed, decrease the counter
bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1; bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1;
...@@ -153,8 +153,9 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const { ...@@ -153,8 +153,9 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
any_op_output_name.empty(), true, any_op_output_name.empty(), true,
platform::errors::NotFound("Operator before requantize operator " platform::errors::NotFound("Operator before requantize operator(%s) "
"should have requantize input as output")); "should have requantize input as output.",
requant_in->Name()));
float requant_scale_out = float requant_scale_out =
BOOST_GET_CONST(float, requant_op->Op()->GetAttr("Scale_out")); BOOST_GET_CONST(float, requant_op->Op()->GetAttr("Scale_out"));
...@@ -195,10 +196,11 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const { ...@@ -195,10 +196,11 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
for (auto input_name : any_op->Op()->Input(name)) for (auto input_name : any_op->Op()->Input(name))
if (input_name == requant_out->Name()) any_op_input_name = name; if (input_name == requant_out->Name()) any_op_input_name = name;
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(any_op_input_name.empty(), true,
any_op_input_name.empty(), true, platform::errors::NotFound(
platform::errors::NotFound("The operator after requantize operator " "The operator after requantize operator(%s) "
"should have requantize output as input")); "should have requantize output as input.",
requant_out->Name()));
float requant_scale_in = float requant_scale_in =
boost::get<float>(requant_op->Op()->GetAttr("Scale_in")); boost::get<float>(requant_op->Op()->GetAttr("Scale_in"));
...@@ -206,11 +208,14 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const { ...@@ -206,11 +208,14 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
if (any_op->Op()->Type() == "matmul") if (any_op->Op()->Type() == "matmul")
scale_name = any_op_input_name == "X" ? "Scale_x" : "Scale_y"; scale_name = any_op_input_name == "X" ? "Scale_x" : "Scale_y";
PADDLE_ENFORCE_EQ(requant_op->Op()->GetAttrIfExists<float>("Scale_out"), PADDLE_ENFORCE_EQ(
any_op->Op()->GetAttrIfExists<float>(scale_name), requant_op->Op()->GetAttrIfExists<float>("Scale_out"),
platform::errors::InvalidArgument( any_op->Op()->GetAttrIfExists<float>(scale_name),
"The operator after requantize should have input " platform::errors::InvalidArgument(
"scale equal to requantize output scale")); "The operator after requantize should have input "
"scale(%f) equal to requantize output scale(%f).",
any_op->Op()->GetAttrIfExists<float>(scale_name),
requant_op->Op()->GetAttrIfExists<float>("Scale_out")));
any_op->Op()->SetAttr(scale_name, requant_scale_in); any_op->Op()->SetAttr(scale_name, requant_scale_in);
any_op->Op()->SetInput(any_op_input_name, any_op->Op()->SetInput(any_op_input_name,
std::vector<std::string>({requant_in->Name()})); std::vector<std::string>({requant_in->Name()}));
...@@ -286,8 +291,9 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { ...@@ -286,8 +291,9 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
auto* first_quant_out = first_quant_op->outputs[0]; auto* first_quant_out = first_quant_op->outputs[0];
float scale = first_quant_op->Op()->GetAttrIfExists<float>("Scale"); float scale = first_quant_op->Op()->GetAttrIfExists<float>("Scale");
PADDLE_ENFORCE_NE(scale, 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_NE(scale, 0,
"Quantize scale should not be equal 0")); platform::errors::InvalidArgument(
"Quantize scale(%f) should not be equal 0.", scale));
for (int iter = prev_out->outputs.size() - 1; iter >= 0; iter--) { for (int iter = prev_out->outputs.size() - 1; iter >= 0; iter--) {
auto quant_op = prev_out->outputs[iter]; auto quant_op = prev_out->outputs[iter];
...@@ -304,8 +310,9 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { ...@@ -304,8 +310,9 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
last_op_input_name.empty(), true, last_op_input_name.empty(), true,
platform::errors::NotFound("Operator after quantize operator " platform::errors::NotFound("Operator after quantize operator(%s) "
"should has quantize output as input")); "should has quantize output as input.",
quant_out->Name()));
last_op->Op()->SetInput( last_op->Op()->SetInput(
last_op_input_name, last_op_input_name,
std::vector<std::string>({first_quant_out->Name()})); std::vector<std::string>({first_quant_out->Name()}));
...@@ -345,10 +352,12 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const { ...@@ -345,10 +352,12 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
PADDLE_ENFORCE_GT(dequant_scale, 0.0f, PADDLE_ENFORCE_GT(dequant_scale, 0.0f,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Dequantize scale should have positive value")); "Dequantize scale(%f) should have positive value.",
dequant_scale));
PADDLE_ENFORCE_GT(scale_scale, 0.0f, PADDLE_ENFORCE_GT(scale_scale, 0.0f,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Scale of scale op should have positive value")); "Scale(%f) of scale op should have positive value.",
scale_scale));
dequant_op->Op()->SetAttr("Scale", dequant_scale / scale_scale); dequant_op->Op()->SetAttr("Scale", dequant_scale / scale_scale);
dequant_op->Op()->SetOutput( dequant_op->Op()->SetOutput(
...@@ -367,8 +376,8 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const { ...@@ -367,8 +376,8 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, graph,
platform::errors::NotFound( platform::errors::InvalidArgument(
"The graph in function CPUQuantizeSquashPass::ApplyImpl is null")); "The graph in function CPUQuantizeSquashPass::ApplyImpl is null."));
FusePassBase::Init("cpu_quantize_squash_pass", graph); FusePassBase::Init("cpu_quantize_squash_pass", graph);
std::unordered_map<const Node*, int> nodes_keep_counter; std::unordered_map<const Node*, int> nodes_keep_counter;
......
...@@ -57,7 +57,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -57,7 +57,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
PADDLE_ENFORCE_EQ(inputs.size(), 2UL, PADDLE_ENFORCE_EQ(inputs.size(), 2UL,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The fc inputs should contain input and weights, but " "The fc inputs should contain input and weights, but "
"now the size of inputs is %d", "now the size of inputs is %d.",
inputs.size())); inputs.size()));
op->SetInput("W", {inputs[1]}); op->SetInput("W", {inputs[1]});
op->SetOutput("Out", outputs); op->SetOutput("Out", outputs);
......
...@@ -19,14 +19,17 @@ namespace paddle { ...@@ -19,14 +19,17 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
#define GET_NODE(id, pattern) \ #define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \ PADDLE_ENFORCE_NE(subgraph.count(pattern.RetrieveNode(#id)), 0, \
"pattern has no Node called %s", #id); \ platform::errors::InvalidArgument( \
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \ "Pattern has no Node called %s.", #id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL( \
id, platform::errors::InvalidArgument("Subgraph has no node %s.", #id));
void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const { void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("depthwise_conv_mkldnn_pass", graph); FusePassBase::Init("depthwise_conv_mkldnn_pass", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
......
...@@ -46,12 +46,15 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -46,12 +46,15 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
if (scale_op->Op()->GetAttrIfExists<float>("bias") == 0.0) { if (scale_op->Op()->GetAttrIfExists<float>("bias") == 0.0) {
auto matmul_alpha = matmul_op->Op()->GetAttrIfExists<float>("alpha"); auto matmul_alpha = matmul_op->Op()->GetAttrIfExists<float>("alpha");
auto scale_scale = scale_op->Op()->GetAttrIfExists<float>("scale"); auto scale_scale = scale_op->Op()->GetAttrIfExists<float>("scale");
PADDLE_ENFORCE_GT(matmul_alpha, 0.0f, PADDLE_ENFORCE_GT(
platform::errors::InvalidArgument( matmul_alpha, 0.0f,
"Alpha of matmul op should have positive value")); platform::errors::InvalidArgument(
"Alpha(%f) of matmul op should have positive value.",
matmul_alpha));
PADDLE_ENFORCE_GT(scale_scale, 0.0f, PADDLE_ENFORCE_GT(scale_scale, 0.0f,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Scale of scale op should have positive value")); "Scale(%f) of scale op should have positive value.",
scale_scale));
std::string matmul_op_input_name; std::string matmul_op_input_name;
for (auto name : matmul_op->Op()->InputNames()) for (auto name : matmul_op->Op()->InputNames())
...@@ -60,8 +63,9 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -60,8 +63,9 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
matmul_op_input_name.empty(), true, matmul_op_input_name.empty(), true,
platform::errors::NotFound("Operator after scale operator " platform::errors::NotFound("Operator after scale operator(%s) "
"should have scale output as input")); "should have scale output as input.",
scale_out->Name()));
matmul_op->Op()->SetAttr("alpha", matmul_alpha * scale_scale); matmul_op->Op()->SetAttr("alpha", matmul_alpha * scale_scale);
matmul_op->Op()->SetInput(matmul_op_input_name, matmul_op->Op()->SetInput(matmul_op_input_name,
std::vector<std::string>({scale_in->Name()})); std::vector<std::string>({scale_in->Name()}));
......
...@@ -85,7 +85,9 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const { ...@@ -85,7 +85,9 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
// 1. record op nodes of different roles // 1. record op nodes of different roles
for (auto node : nodes) { for (auto node : nodes) {
if (!node->IsOp()) continue; if (!node->IsOp()) continue;
PADDLE_ENFORCE(node->Op(), "must find opdesc"); PADDLE_ENFORCE_NOT_NULL(
node->Op(), platform::errors::InvalidArgument(
"Node(%s) must hold op description.", node->Name()));
int op_role = BOOST_GET_CONST( int op_role = BOOST_GET_CONST(
int, node->Op()->GetAttr( int, node->Op()->GetAttr(
framework::OpProtoAndCheckerMaker::OpRoleAttrName())); framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
...@@ -108,7 +110,9 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const { ...@@ -108,7 +110,9 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
} else if (op_role & static_cast<int>(framework::OpRole::kLRSched)) { } else if (op_role & static_cast<int>(framework::OpRole::kLRSched)) {
lr_ops.push_back(node); lr_ops.push_back(node);
} else { // NOLINT } else { // NOLINT
PADDLE_THROW("Invalid op_role: %d", static_cast<int>(op_role)); PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid op role(%d), in node(%s).", static_cast<int>(op_role),
node->Name()));
} }
} }
......
...@@ -45,7 +45,9 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -45,7 +45,9 @@ class AllReduceDepsPass : public ir::Pass {
for (size_t i = 0; i < all_reduce_op_handles.size(); ++i) { for (size_t i = 0; i < all_reduce_op_handles.size(); ++i) {
auto op_handle = auto op_handle =
dynamic_cast<details::NCCLOpHandleBase*>(all_reduce_op_handles[i]); dynamic_cast<details::NCCLOpHandleBase*>(all_reduce_op_handles[i]);
PADDLE_ENFORCE(op_handle, "op_handle must be NCCLOpHandleBase"); PADDLE_ENFORCE_NOT_NULL(op_handle,
platform::errors::InvalidArgument(
"Op handle must be NCCLOpHandleBase."));
op_handle->SetRunEnv(i, use_hierarchical_allreduce); op_handle->SetRunEnv(i, use_hierarchical_allreduce);
} }
#endif #endif
...@@ -95,7 +97,9 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -95,7 +97,9 @@ class AllReduceDepsPass : public ir::Pass {
} }
} }
PADDLE_ENFORCE_NE(next_ready_ops.size(), 0, "There maybe have a cycle."); PADDLE_ENFORCE_NE(
next_ready_ops.size(), 0,
platform::errors::InvalidArgument("There may be a cycle."));
ready_ops.clear(); ready_ops.clear();
std::swap(ready_ops, next_ready_ops); std::swap(ready_ops, next_ready_ops);
GetSortedAllReduceOps(ready_ops, &all_reduce_op_handles); GetSortedAllReduceOps(ready_ops, &all_reduce_op_handles);
...@@ -122,18 +126,25 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -122,18 +126,25 @@ class AllReduceDepsPass : public ir::Pass {
// NOTE(zcd): For distributed training, it is important to keep the order of // NOTE(zcd): For distributed training, it is important to keep the order of
// allReduce on each node consistent. Otherwise, hang may occur. // allReduce on each node consistent. Otherwise, hang may occur.
// Sort the current_all_reduce_op_handles according to the name of input. // Sort the current_all_reduce_op_handles according to the name of input.
sort(current_all_reduce_op_handles.begin(), sort(
current_all_reduce_op_handles.end(), current_all_reduce_op_handles.begin(),
[](const details::OpHandleBase* left, current_all_reduce_op_handles.end(),
const details::OpHandleBase* right) -> bool { [](const details::OpHandleBase* left,
auto left_in_vars = const details::OpHandleBase* right) -> bool {
details::DynamicCast<details::VarHandle>(left->Inputs()); auto left_in_vars =
auto right_in_vars = details::DynamicCast<details::VarHandle>(left->Inputs());
details::DynamicCast<details::VarHandle>(right->Inputs()); auto right_in_vars =
PADDLE_ENFORCE_GT(left_in_vars.size(), 0); details::DynamicCast<details::VarHandle>(right->Inputs());
PADDLE_ENFORCE_GT(right_in_vars.size(), 0); PADDLE_ENFORCE_GT(left_in_vars.size(), 0,
return left_in_vars[0]->Name() > right_in_vars[0]->Name(); platform::errors::InvalidArgument(
}); "OpHandle(%s) inputs size must greater than 0.",
left->Name()));
PADDLE_ENFORCE_GT(right_in_vars.size(), 0,
platform::errors::InvalidArgument(
"OpHandle(%s) inputs size must greater than 0.",
right->Name()));
return left_in_vars[0]->Name() > right_in_vars[0]->Name();
});
all_reduce_op_handles->insert(all_reduce_op_handles->end(), all_reduce_op_handles->insert(all_reduce_op_handles->end(),
current_all_reduce_op_handles.begin(), current_all_reduce_op_handles.begin(),
...@@ -170,7 +181,10 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -170,7 +181,10 @@ class AllReduceDepsPass : public ir::Pass {
break; break;
} }
} }
PADDLE_ENFORCE(find_valid_input, "Doesn't find valid input."); PADDLE_ENFORCE_EQ(
find_valid_input, true,
platform::errors::NotFound(
"In OpHandle(%s) Doesn't find valid input.", op->Name()));
} }
VLOG(10) << out2.str(); VLOG(10) << out2.str();
if (grads_of_stale_program != all_reduce_op_handles.size()) { if (grads_of_stale_program != all_reduce_op_handles.size()) {
......
...@@ -179,9 +179,10 @@ class BackWardOpDepsPass : public ir::Pass { ...@@ -179,9 +179,10 @@ class BackWardOpDepsPass : public ir::Pass {
// Currently, we assume that once gradient is generated, it can be // Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. // broadcast, and each gradient is only broadcast once.
auto backward_vars = details::GetOpRoleVarsOrEmpty(op_desc); auto backward_vars = details::GetOpRoleVarsOrEmpty(op_desc);
PADDLE_ENFORCE_EQ(node->IsWrappedBy<details::OpHandleBase>(), true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( node->IsWrappedBy<details::OpHandleBase>(), true,
"Node must be wrapped by OpHandleBase")); platform::errors::InvalidArgument(
"Node(%s) must be wrapped by OpHandleBase.", node->Name()));
backward_op_handles->emplace_back(&node->Wrapper<details::OpHandleBase>()); backward_op_handles->emplace_back(&node->Wrapper<details::OpHandleBase>());
......
...@@ -64,9 +64,10 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -64,9 +64,10 @@ class FuseAllReduceOpPass : public ir::Pass {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
all_reduce_ops.size(), grads.size(), all_reduce_ops.size(), grads.size(),
platform::errors::Unimplemented( platform::errors::Unimplemented(
"The number of all_reduce OpHandle is not equal to the " "The number of all_reduce OpHandle(%d) is not equal to the "
"number of grads. Maybe some gradients are sparse type, " "number of grads(%d). Maybe some gradients are sparse type, "
"it is not supported currently.")); "it is not supported currently.",
all_reduce_ops.size(), grads.size()));
auto &group_params_grads = graph->Get<details::GroupParamsAndGrads>( auto &group_params_grads = graph->Get<details::GroupParamsAndGrads>(
details::kGroupParamsAndDenseGrads); details::kGroupParamsAndDenseGrads);
...@@ -79,7 +80,10 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -79,7 +80,10 @@ class FuseAllReduceOpPass : public ir::Pass {
for (auto &group_p_g : group_params_grads) { for (auto &group_p_g : group_params_grads) {
size_t group_size = group_p_g.size(); size_t group_size = group_p_g.size();
PADDLE_ENFORCE_GT(group_size, static_cast<size_t>(0)); PADDLE_ENFORCE_GT(
group_size, static_cast<size_t>(0),
platform::errors::InvalidArgument(
"Parameter and Parameter@grad in one group, must not be empty."));
std::vector<ir::Node *> group_all_reduce_ops; std::vector<ir::Node *> group_all_reduce_ops;
group_all_reduce_ops.reserve(group_size); group_all_reduce_ops.reserve(group_size);
for (auto &p_g : group_p_g) { for (auto &p_g : group_p_g) {
...@@ -103,26 +107,40 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -103,26 +107,40 @@ class FuseAllReduceOpPass : public ir::Pass {
all_reduce_ops.reserve(grads.size()); all_reduce_ops.reserve(grads.size());
for (auto &node : result.Nodes()) { for (auto &node : result.Nodes()) {
if (node->IsOp()) { if (node->IsOp()) {
PADDLE_ENFORCE(node->IsWrappedBy<details::OpHandleBase>()); PADDLE_ENFORCE_EQ(
node->IsWrappedBy<details::OpHandleBase>(), true,
platform::errors::InvalidArgument(
"Op Node(%s) should Wrapped by OpHandleBase.", node->Name()));
auto *all_reduce_op_handle = dynamic_cast<details::AllReduceOpHandle *>( auto *all_reduce_op_handle = dynamic_cast<details::AllReduceOpHandle *>(
&node->Wrapper<details::OpHandleBase>()); &node->Wrapper<details::OpHandleBase>());
if (all_reduce_op_handle) { if (all_reduce_op_handle) {
#if defined(PADDLE_WITH_DGC) #if defined(PADDLE_WITH_DGC)
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
all_reduce_op_handle->Name(), "sparse_all_reduce", all_reduce_op_handle->Name(), "sparse_all_reduce",
"DGC doesn't support fuse for now, if you want to use DGC " platform::errors::InvalidArgument(
"you need set strategy.fuse_all_reduce_ops = False."); "DGC doesn't support fuse for now, if you want to use DGC "
"you need set strategy.fuse_all_reduce_ops = False."));
#endif #endif
auto inputs = details::DynamicCast<details::VarHandle>( auto inputs = details::DynamicCast<details::VarHandle>(
all_reduce_op_handle->Inputs()); all_reduce_op_handle->Inputs());
PADDLE_ENFORCE_EQ(inputs.size(), num_place); PADDLE_ENFORCE_EQ(inputs.size(), num_place,
platform::errors::InvalidArgument(
"The input size(%d) of all reduce op must "
"equal to place cnt(%d)!",
inputs.size(), num_place));
// The inputs' name should be the same. // The inputs' name should be the same.
auto &grad_name = inputs[0]->name(); auto &grad_name = inputs[0]->name();
for (size_t i = 1; i < inputs.size(); ++i) { for (size_t i = 1; i < inputs.size(); ++i) {
PADDLE_ENFORCE_EQ(inputs[i]->name(), grad_name, PADDLE_ENFORCE_EQ(
"The input name should be the same."); inputs[i]->name(), grad_name,
platform::errors::InvalidArgument(
"The input name should be the same.diff name: %s %s.",
inputs[i]->name(), grad_name));
} }
PADDLE_ENFORCE_NE(grads.count(grad_name), static_cast<size_t>(0)); PADDLE_ENFORCE_NE(
grads.count(grad_name), static_cast<size_t>(0),
platform::errors::InvalidArgument(
"Parameter@grad(%s) must in grad set.", grad_name));
all_reduce_ops.emplace(grad_name, node); all_reduce_ops.emplace(grad_name, node);
} }
} }
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册