提交 d8260d9c 编写于 作者: Z zlsh80826

merge develop, test=develop

...@@ -218,6 +218,9 @@ endif(WITH_AMD_GPU) ...@@ -218,6 +218,9 @@ endif(WITH_AMD_GPU)
if(WITH_ARM) if(WITH_ARM)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(WITH_XBYAK OFF CACHE STRING "Disable XBYAK when compiling WITH_ARM=ON" FORCE)
set(WITH_MKL OFF CACHE STRING "Disable MKL when compiling WITH_ARM=ON." FORCE)
set(WITH_GPU OFF CACHE STRING "Disable GPU when compiling WITH_ARM=ON." FORCE)
add_definitions(-DPADDLE_WITH_ARM) add_definitions(-DPADDLE_WITH_ARM)
endif() endif()
......
 
# PaddlePaddle <p align="center">
<img align="center" src="doc/imgs/logo.png", width=1600>
<p>
--------------------------------------------------------------------------------
English | [简体中文](./README_cn.md) English | [简体中文](./README_cn.md)
...@@ -29,7 +33,7 @@ pip install paddlepaddle ...@@ -29,7 +33,7 @@ pip install paddlepaddle
# Linux GPU cuda10cudnn7 # Linux GPU cuda10cudnn7
pip install paddlepaddle-gpu pip install paddlepaddle-gpu
# Linux GPU cuda9cudnn7 # Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu==1.8.2.post97 pip install paddlepaddle-gpu==1.8.3.post97
``` ```
It is recommended to read [this doc](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/install/index_en.html) on our website. It is recommended to read [this doc](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/install/index_en.html) on our website.
......
 
# PaddlePaddle <p align="center">
<img align="center" src="doc/imgs/logo.png", width=1600>
<p>
--------------------------------------------------------------------------------
[English](./README.md) | 简体中文 [English](./README.md) | 简体中文
...@@ -26,7 +30,7 @@ pip install paddlepaddle ...@@ -26,7 +30,7 @@ pip install paddlepaddle
# Linux GPU cuda10cudnn7 # Linux GPU cuda10cudnn7
pip install paddlepaddle-gpu pip install paddlepaddle-gpu
# Linux GPU cuda9cudnn7 # Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu==1.8.2.post97 pip install paddlepaddle-gpu==1.8.3.post97
``` ```
更多安装信息详见官网 [安装说明](http://www.paddlepaddle.org.cn/documentation/docs/zh/1.8/beginners_guide/install/index_cn.html) 更多安装信息详见官网 [安装说明](http://www.paddlepaddle.org.cn/documentation/docs/zh/1.8/beginners_guide/install/index_cn.html)
......
...@@ -18,6 +18,15 @@ if(NOT LINUX OR NOT WITH_MKL) ...@@ -18,6 +18,15 @@ if(NOT LINUX OR NOT WITH_MKL)
return() return()
endif() endif()
if(XPU_SDK_ROOT)
set(LITE_WITH_XPU ON)
include_directories("${XPU_SDK_ROOT}/XTDK/include")
include_directories("${XPU_SDK_ROOT}/XTCL/include")
add_definitions(-DPADDLE_WITH_XPU)
LINK_DIRECTORIES("${XPU_SDK_ROOT}/XTDK/shlib/")
LINK_DIRECTORIES("${XPU_SDK_ROOT}/XTDK/runtime/shlib/")
endif()
if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
include(ExternalProject) include(ExternalProject)
set(LITE_PROJECT extern_lite) set(LITE_PROJECT extern_lite)
...@@ -25,7 +34,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) ...@@ -25,7 +34,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite) set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite)
if(NOT LITE_GIT_TAG) if(NOT LITE_GIT_TAG)
set(LITE_GIT_TAG ab8af5c4b4dc5b40217633e0aa436315912d7b53) set(LITE_GIT_TAG 42ab4d559f6659edfc35040fb30fdcec3dc3f8aa)
endif() endif()
if(NOT CUDA_ARCH_NAME) if(NOT CUDA_ARCH_NAME)
...@@ -47,6 +56,8 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) ...@@ -47,6 +56,8 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
-DCUDNN_ROOT=${CUDNN_ROOT} -DCUDNN_ROOT=${CUDNN_ROOT}
-DLITE_WITH_STATIC_CUDA=OFF -DLITE_WITH_STATIC_CUDA=OFF
-DCUDA_ARCH_NAME=${CUDA_ARCH_NAME} -DCUDA_ARCH_NAME=${CUDA_ARCH_NAME}
-DLITE_WITH_XPU=${LITE_WITH_XPU}
-DXPU_SDK_ROOT=${XPU_SDK_ROOT}
-DLITE_WITH_ARM=OFF) -DLITE_WITH_ARM=OFF)
ExternalProject_Add( ExternalProject_Add(
...@@ -83,7 +94,7 @@ message(STATUS "Paddle-lite SOURCE_DIR: ${LITE_SOURCE_DIR}") ...@@ -83,7 +94,7 @@ message(STATUS "Paddle-lite SOURCE_DIR: ${LITE_SOURCE_DIR}")
include_directories(${LITE_SOURCE_DIR}) include_directories(${LITE_SOURCE_DIR})
include_directories(${LITE_BINARY_DIR}) include_directories(${LITE_BINARY_DIR})
function(external_lite_static_libs alias path) function(external_lite_libs alias path)
add_library(${alias} SHARED IMPORTED GLOBAL) add_library(${alias} SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ${alias} PROPERTY IMPORTED_LOCATION SET_PROPERTY(TARGET ${alias} PROPERTY IMPORTED_LOCATION
${path}) ${path})
...@@ -92,7 +103,7 @@ function(external_lite_static_libs alias path) ...@@ -92,7 +103,7 @@ function(external_lite_static_libs alias path)
endif() endif()
endfunction() endfunction()
external_lite_static_libs(lite_full_static ${LITE_BINARY_DIR}/inference_lite_lib/cxx/lib/libpaddle_full_api_shared.so) external_lite_libs(lite_full_static ${LITE_BINARY_DIR}/inference_lite_lib/cxx/lib/libpaddle_full_api_shared.so)
set(LITE_SHARED_LIB ${LITE_BINARY_DIR}/inference_lite_lib/cxx/lib/libpaddle_full_api_shared.so) set(LITE_SHARED_LIB ${LITE_BINARY_DIR}/inference_lite_lib/cxx/lib/libpaddle_full_api_shared.so)
add_definitions(-DPADDLE_WITH_LITE) add_definitions(-DPADDLE_WITH_LITE)
......
...@@ -20,6 +20,8 @@ SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas) ...@@ -20,6 +20,8 @@ SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas)
SET(CBLAS_REPOSITORY https://github.com/xianyi/OpenBLAS.git) SET(CBLAS_REPOSITORY https://github.com/xianyi/OpenBLAS.git)
SET(CBLAS_TAG v0.3.7) SET(CBLAS_TAG v0.3.7)
IF(WITH_ARM) IF(WITH_ARM)
# Under the FT2000 architecture, the calculation result of blas.sgemm in openblas 0.3+ is wrong,
# so version 0.2 is used by default.
SET(CBLAS_TAG v0.2.18) SET(CBLAS_TAG v0.2.18)
ENDIF() ENDIF()
cache_third_party(extern_openblas cache_third_party(extern_openblas
......
...@@ -145,9 +145,9 @@ if (NOT "${PROTOBUF_ROOT}" STREQUAL "") ...@@ -145,9 +145,9 @@ if (NOT "${PROTOBUF_ROOT}" STREQUAL "")
find_program(PROTOBUF_PROTOC_EXECUTABLE protoc PATHS ${PROTOBUF_ROOT}/bin NO_DEFAULT_PATH) find_program(PROTOBUF_PROTOC_EXECUTABLE protoc PATHS ${PROTOBUF_ROOT}/bin NO_DEFAULT_PATH)
if (PROTOBUF_INCLUDE_DIR AND PROTOBUF_LIBRARY AND PROTOBUF_LITE_LIBRARY AND PROTOBUF_PROTOC_LIBRARY AND PROTOBUF_PROTOC_EXECUTABLE) if (PROTOBUF_INCLUDE_DIR AND PROTOBUF_LIBRARY AND PROTOBUF_LITE_LIBRARY AND PROTOBUF_PROTOC_LIBRARY AND PROTOBUF_PROTOC_EXECUTABLE)
SET(PROTOBUF_FOUND true) SET(PROTOBUF_FOUND true)
message(STATUS "Using custom protobuf library in ${PROTOBUF_ROOT}.")
SET_PROTOBUF_VERSION() SET_PROTOBUF_VERSION()
PROMPT_PROTOBUF_LIB() PROMPT_PROTOBUF_LIB()
message(STATUS "Using custom protobuf library in ${PROTOBUF_ROOT}.")
endif() endif()
endif() endif()
......
...@@ -8,6 +8,8 @@ function(CheckCompilerCXX11Flag) ...@@ -8,6 +8,8 @@ function(CheckCompilerCXX11Flag)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
elseif(${CMAKE_CXX_COMPILER_VERSION} VERSION_GREATER 8.2)
message(WARNING "Found GCC ${CMAKE_CXX_COMPILER_VERSION} which is too high, recommended to use GCC 8.2")
endif() endif()
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
......
...@@ -819,20 +819,18 @@ function(brpc_library TARGET_NAME) ...@@ -819,20 +819,18 @@ function(brpc_library TARGET_NAME)
cc_library("${TARGET_NAME}" SRCS "${brpc_library_SRCS}" DEPS "${TARGET_NAME}_proto" "${brpc_library_DEPS}") cc_library("${TARGET_NAME}" SRCS "${brpc_library_SRCS}" DEPS "${TARGET_NAME}_proto" "${brpc_library_DEPS}")
endfunction() endfunction()
# copy_if_different from src_file to dst_file before barrier_target. # copy_if_different from src_file to dst_file At the beginning of the build.
function(copy_if_different src_file dst_file barrier_target) function(copy_if_different src_file dst_file)
# this is a dummy target, should always be run to update ${dst_file} get_filename_component(FILE_NAME ${dst_file} NAME_WE)
add_custom_target(before_${barrier_target} ALL
DEPENDS before_${barrier_target}_custom_command
)
add_dependencies(${barrier_target} before_${barrier_target})
add_custom_command( # this is a dummy target for custom command, should always be run firstly to update ${dst_file}
OUTPUT before_${barrier_target}_custom_command add_custom_target(copy_${FILE_NAME}_command ALL
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${src_file} ${dst_file} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${src_file} ${dst_file}
COMMENT "copy_if_different ${dst_file}" COMMENT "copy_if_different ${dst_file}"
VERBATIM VERBATIM
) )
add_dependencies(extern_glog copy_${FILE_NAME}_command)
endfunction() endfunction()
# create a dummy source file, then create a static library. # create a dummy source file, then create a static library.
......
...@@ -19,9 +19,12 @@ set(FLUID_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_install_dir" CACHE STRING ...@@ -19,9 +19,12 @@ set(FLUID_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_install_dir" CACHE STRING
set(FLUID_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_install_dir" CACHE STRING set(FLUID_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_install_dir" CACHE STRING
"A path setting fluid inference shared and static libraries") "A path setting fluid inference shared and static libraries")
# TODO(zhaolong)
# At present, the size of static lib in Windows exceeds the system limit,
# so the generation of static lib is temporarily turned off.
if(WIN32) if(WIN32)
#todo: remove the option #todo: remove the option
option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON) option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." OFF)
if(NOT PYTHON_EXECUTABLE) if(NOT PYTHON_EXECUTABLE)
FIND_PACKAGE(PythonInterp REQUIRED) FIND_PACKAGE(PythonInterp REQUIRED)
endif() endif()
...@@ -187,21 +190,18 @@ copy(inference_lib_dist ...@@ -187,21 +190,18 @@ copy(inference_lib_dist
SRCS ${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io/crypto/cipher.h SRCS ${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io/crypto/cipher.h
DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include/crypto/) DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include/crypto/)
include_directories(${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io) include_directories(${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io)
# CAPI inference library for only inference # CAPI inference library for only inference
set(FLUID_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_c_install_dir" CACHE STRING set(FLUID_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_c_install_dir" CACHE STRING
"A path setting CAPI fluid inference shared") "A path setting CAPI fluid inference shared")
copy_part_of_thrid_party(inference_lib_dist ${FLUID_INFERENCE_C_INSTALL_DIR}) copy_part_of_thrid_party(inference_lib_dist ${FLUID_INFERENCE_C_INSTALL_DIR})
set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid") set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")
if(WIN32) set(paddle_fluid_c_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/capi/libpaddle_fluid_c.*)
set(paddle_fluid_c_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/capi/${CMAKE_BUILD_TYPE}/paddle_fluid_c.*)
else(WIN32)
set(paddle_fluid_c_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/capi/libpaddle_fluid_c.*)
endif(WIN32)
copy(inference_lib_dist copy(inference_lib_dist
SRCS ${src_dir}/inference/capi/paddle_c_api.h ${paddle_fluid_c_lib} SRCS ${src_dir}/inference/capi/paddle_c_api.h ${paddle_fluid_c_lib}
DSTS ${FLUID_INFERENCE_C_INSTALL_DIR}/paddle/include ${FLUID_INFERENCE_C_INSTALL_DIR}/paddle/lib) DSTS ${FLUID_INFERENCE_C_INSTALL_DIR}/paddle/include ${FLUID_INFERENCE_C_INSTALL_DIR}/paddle/lib)
# fluid library for both train and inference # fluid library for both train and inference
set(fluid_lib_deps inference_lib_dist) set(fluid_lib_deps inference_lib_dist)
......
...@@ -7,14 +7,14 @@ if(WIN32) ...@@ -7,14 +7,14 @@ if(WIN32)
return() return()
endif() endif()
set(NCCL_ROOT "/usr" CACHE PATH "NCCL ROOT")
find_path(NCCL_INCLUDE_DIR nccl.h
PATHS ${NCCL_ROOT} ${NCCL_ROOT}/include ${NCCL_ROOT}/local/include
$ENV{NCCL_ROOT} $ENV{NCCL_ROOT}/include $ENV{NCCL_ROOT}/local/include
NO_DEFAULT_PATH
)
if(WITH_NCCL) if(WITH_NCCL)
set(NCCL_ROOT "/usr" CACHE PATH "NCCL ROOT")
find_path(NCCL_INCLUDE_DIR nccl.h
PATHS ${NCCL_ROOT} ${NCCL_ROOT}/include ${NCCL_ROOT}/local/include
$ENV{NCCL_ROOT} $ENV{NCCL_ROOT}/include $ENV{NCCL_ROOT}/local/include
NO_DEFAULT_PATH
)
file(READ ${NCCL_INCLUDE_DIR}/nccl.h NCCL_VERSION_FILE_CONTENTS) file(READ ${NCCL_INCLUDE_DIR}/nccl.h NCCL_VERSION_FILE_CONTENTS)
string(REGEX MATCH "define NCCL_VERSION_CODE +([0-9]+)" string(REGEX MATCH "define NCCL_VERSION_CODE +([0-9]+)"
......
...@@ -114,7 +114,7 @@ function(op_library TARGET) ...@@ -114,7 +114,7 @@ function(op_library TARGET)
endif() endif()
# Define operators that don't need pybind here. # Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_reduce_op" "compare_op" "logical_op" "nccl_op" foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
......
...@@ -63,7 +63,8 @@ class Array { ...@@ -63,7 +63,8 @@ class Array {
HOSTDEVICE inline const T &at(size_t i) const { HOSTDEVICE inline const T &at(size_t i) const {
#ifndef __CUDA_ARCH__ #ifndef __CUDA_ARCH__
PADDLE_ENFORCE_LT(i, N, "Array index out of bounds"); PADDLE_ENFORCE_LT(
i, N, platform::errors::OutOfRange("Array index out of bounds."));
#endif #endif
return (*this)[i]; return (*this)[i];
} }
...@@ -106,7 +107,7 @@ class Array<T, 0> { ...@@ -106,7 +107,7 @@ class Array<T, 0> {
static T obj(); static T obj();
return obj; return obj;
#else #else
PADDLE_THROW("Array<T, 0> has no element"); PADDLE_THROW(platform::errors::Unavailable("Array<T, 0> has no element."));
#endif #endif
} }
...@@ -115,7 +116,7 @@ class Array<T, 0> { ...@@ -115,7 +116,7 @@ class Array<T, 0> {
static const T obj(); static const T obj();
return obj; return obj;
#else #else
PADDLE_THROW("Array<T, 0> has no element"); PADDLE_THROW(platform::errors::Unavailable("Array<T, 0> has no element."));
#endif #endif
} }
......
...@@ -77,11 +77,13 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -77,11 +77,13 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto var_name : fetch_var_names) { for (auto var_name : fetch_var_names) {
auto var_desc = block.FindVar(var_name); auto var_desc = block.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
var_desc, platform::errors::NotFound("%s is not found.", var_name)); var_desc, platform::errors::NotFound(
"Variable %s is not found in main program.", var_name));
auto shapes = var_desc->GetShape(); auto shapes = var_desc->GetShape();
PADDLE_ENFORCE(shapes[shapes.size() - 1] == 1, PADDLE_ENFORCE_EQ(shapes[shapes.size() - 1], 1,
"var %s: Fetched var has wrong shape, " platform::errors::InvalidArgument(
"only variables with the last dimension size 1 supported", "Fetched variable %s has wrong shape, "
"only variables whose last dimension is 1 are supported",
var_name); var_name);
} }
...@@ -95,7 +97,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -95,7 +97,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
actual_thread_num_ = thread_num; actual_thread_num_ = thread_num;
int file_cnt = filelist.size(); int file_cnt = filelist.size();
PADDLE_ENFORCE_GT(file_cnt, 0, PADDLE_ENFORCE_GT(file_cnt, 0,
platform::errors::NotFound("Input file list is empty")); platform::errors::NotFound("Input file list is empty."));
if (actual_thread_num_ > file_cnt) { if (actual_thread_num_ > file_cnt) {
VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
......
...@@ -72,7 +72,8 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) { ...@@ -72,7 +72,8 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
return val; return val;
} }
default: default:
PADDLE_THROW("Unsupport attr type %d", attr_desc.type()); PADDLE_THROW(platform::errors::Unavailable("Unsupport attribute type %d.",
attr_desc.type()));
} }
return boost::blank(); return boost::blank();
} }
......
...@@ -37,9 +37,10 @@ struct ExtractAttribute { ...@@ -37,9 +37,10 @@ struct ExtractAttribute {
try { try {
attr_value = &boost::get<T>(attr); attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) { } catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s", PADDLE_THROW(platform::errors::InvalidArgument(
attr_name_, paddle::platform::demangle(typeid(T).name()), "Cannot get attribute (%s) by type %s, its type is %s.", attr_name_,
paddle::platform::demangle(attr.type().name())); paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name())));
} }
return attr_value; return attr_value;
} }
...@@ -70,8 +71,9 @@ struct ExtractAttribute<bool> { ...@@ -70,8 +71,9 @@ struct ExtractAttribute<bool> {
try { try {
attr_value = &boost::get<bool>(attr); attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) { } catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s", PADDLE_THROW(platform::errors::InvalidArgument(
attr_name_, paddle::platform::demangle(attr.type().name())); "Cannot get attribute (%s) by type bool, its type is %s.", attr_name_,
paddle::platform::demangle(attr.type().name())));
} }
return attr_value; return attr_value;
} }
...@@ -96,8 +98,9 @@ struct ExtractAttribute<int64_t> { ...@@ -96,8 +98,9 @@ struct ExtractAttribute<int64_t> {
try { try {
attr_value = &boost::get<int64_t>(attr); attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) { } catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s", PADDLE_THROW(platform::errors::InvalidArgument(
attr_name_, paddle::platform::demangle(attr.type().name())); "Cannot get attribute (%s) by type int64_t, its type is %s.",
attr_name_, paddle::platform::demangle(attr.type().name())));
} }
return attr_value; return attr_value;
} }
...@@ -124,8 +127,10 @@ struct ExtractAttribute<std::vector<int64_t>> { ...@@ -124,8 +127,10 @@ struct ExtractAttribute<std::vector<int64_t>> {
try { try {
attr_value = &boost::get<std::vector<int64_t>>(attr); attr_value = &boost::get<std::vector<int64_t>>(attr);
} catch (boost::bad_get& bad_get) { } catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s", PADDLE_THROW(platform::errors::InvalidArgument(
attr_name_, paddle::platform::demangle(attr.type().name())); "Cannot get attribute (%s) by type std::vector<int64_t>, its type is "
"%s.",
attr_name_, paddle::platform::demangle(attr.type().name())));
} }
return attr_value; return attr_value;
} }
...@@ -150,8 +155,9 @@ struct ExtractAttribute<float> { ...@@ -150,8 +155,9 @@ struct ExtractAttribute<float> {
try { try {
attr_value = &boost::get<float>(attr); attr_value = &boost::get<float>(attr);
} catch (boost::bad_get& bad_get) { } catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type float, its type is %s", PADDLE_THROW(platform::errors::InvalidArgument(
attr_name_, paddle::platform::demangle(attr.type().name())); "Cannot get attribute (%s) by type float, its type is %s.",
attr_name_, paddle::platform::demangle(attr.type().name())));
} }
return attr_value; return attr_value;
} }
...@@ -173,8 +179,9 @@ class AttrReader { ...@@ -173,8 +179,9 @@ class AttrReader {
template <typename T> template <typename T>
inline const T& Get(const std::string& name) const { inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", PADDLE_ENFORCE_NE(attrs_.count(name), 0,
name); platform::errors::NotFound(
"Attribute (%s) should be in AttributeMap.", name));
Attribute& attr = const_cast<Attribute&>(attrs_.at(name)); Attribute& attr = const_cast<Attribute&>(attrs_.at(name));
ExtractAttribute<T> extract_attr(name); ExtractAttribute<T> extract_attr(name);
...@@ -192,8 +199,10 @@ class GreaterThanChecker { ...@@ -192,8 +199,10 @@ class GreaterThanChecker {
public: public:
explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(const T& value) const { void operator()(const T& value) const {
PADDLE_ENFORCE_GT(value, lower_bound_, PADDLE_ENFORCE_GT(
platform::errors::OutOfRange("larger_than check fails.")); value, lower_bound_,
platform::errors::OutOfRange(
"Check for attribute value greater than a certain value failed."));
} }
private: private:
...@@ -205,7 +214,10 @@ class EqualGreaterThanChecker { ...@@ -205,7 +214,10 @@ class EqualGreaterThanChecker {
public: public:
explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(const T& value) const { void operator()(const T& value) const {
PADDLE_ENFORCE_GE(value, lower_bound_, "equal_larger_than check fails."); PADDLE_ENFORCE_GE(
value, lower_bound_,
platform::errors::OutOfRange("Check for attribute valur equal or "
"greater than a certain value failed."));
} }
private: private:
...@@ -231,9 +243,10 @@ class EnumInContainer { ...@@ -231,9 +243,10 @@ class EnumInContainer {
public: public:
explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {} explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {}
void operator()(const T& val) const { void operator()(const T& val) const {
PADDLE_ENFORCE(container_.find(val) != container_.end(), PADDLE_ENFORCE_NE(
"Value %s is not in enum container %s", val, container_.find(val), container_.end(),
ContainerDebugString()); platform::errors::NotFound("Value %s is not in enum container %s.", val,
ContainerDebugString()));
} }
private: private:
...@@ -284,8 +297,11 @@ class TypedAttrChecker { ...@@ -284,8 +297,11 @@ class TypedAttrChecker {
// we can add more common limits, like LessThan(), Between()... // we can add more common limits, like LessThan(), Between()...
TypedAttrChecker& SetDefault(const T& default_value) { TypedAttrChecker& SetDefault(const T& default_value) {
PADDLE_ENFORCE(default_value_setter_.empty(), PADDLE_ENFORCE_EQ(
"%s can't have more than one default value!", attr_name_); default_value_setter_.empty(), true,
platform::errors::AlreadyExists(
"Attribute (%s) has a default value and cannot be set repeatedly.",
attr_name_));
default_value_setter_.push_back(DefaultValueSetter<T>(default_value)); default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
return *this; return *this;
} }
...@@ -308,8 +324,10 @@ class TypedAttrChecker { ...@@ -308,8 +324,10 @@ class TypedAttrChecker {
auto it = attr_map->find(attr_name_); auto it = attr_map->find(attr_name_);
if (it == attr_map->end()) { if (it == attr_map->end()) {
// user do not set this attr // user do not set this attr
PADDLE_ENFORCE(!default_value_setter_.empty(), PADDLE_ENFORCE_EQ(
"Attribute '%s' is required!", attr_name_); default_value_setter_.empty(), false,
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element // default_value_setter_ has no more than one element
attr_map->emplace(attr_name_, default_value_setter_[0]()); attr_map->emplace(attr_name_, default_value_setter_[0]());
} }
......
...@@ -23,11 +23,14 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place, ...@@ -23,11 +23,14 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
in.place().which(), dst_place.which(), in.place().which(), dst_place.which(),
"Currently, model parallelism is only supported between CPU and CUDA"); platform::errors::Unavailable("Currently, model parallelism is only "
"supported between CPU and CUDA."));
// NOTE(yy): TransDataDevice should wait for computation of input. // NOTE(yy): TransDataDevice should wait for computation of input.
platform::DeviceContextPool::Instance().Get(in.place())->Wait(); if (!platform::is_cuda_pinned_place(in.place())) {
platform::DeviceContextPool::Instance().Get(dst_place)->Wait(); platform::DeviceContextPool::Instance().Get(in.place())->Wait();
platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
}
// FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and // FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and
// the enforced checkings have been done in GetDeviceContext, so the // the enforced checkings have been done in GetDeviceContext, so the
......
...@@ -133,11 +133,14 @@ bool DataFeed::PickOneFile(std::string* filename) { ...@@ -133,11 +133,14 @@ bool DataFeed::PickOneFile(std::string* filename) {
} }
void DataFeed::CheckInit() { void DataFeed::CheckInit() {
PADDLE_ENFORCE(finish_init_, "Initialization did not succeed."); PADDLE_ENFORCE_EQ(finish_init_, true, platform::errors::PreconditionNotMet(
"DataFeed initialization failed."));
} }
void DataFeed::CheckSetFileList() { void DataFeed::CheckSetFileList() {
PADDLE_ENFORCE(finish_set_filelist_, "Set filelist did not succeed."); PADDLE_ENFORCE_EQ(
finish_set_filelist_, true,
platform::errors::PreconditionNotMet("DataFeed set filelist failed."));
} }
void DataFeed::CheckStart() { void DataFeed::CheckStart() {
...@@ -160,14 +163,18 @@ void DataFeed::CopyToFeedTensor(void* dst, const void* src, size_t size) { ...@@ -160,14 +163,18 @@ void DataFeed::CopyToFeedTensor(void* dst, const void* src, size_t size) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice);
#else #else
PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option"); PADDLE_THROW(platform::errors::Unimplemented(
"Not supported GPU, please compile with option WITH_GPU=ON."));
#endif #endif
} }
} }
template <typename T> template <typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) { void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
PADDLE_ENFORCE(queue_size > 0, "Illegal queue size: %d.", queue_size); PADDLE_ENFORCE_GT(
queue_size, 0,
platform::errors::InvalidArgument(
"Queue size %d is illegal in PrivateQueueDataFeed.", queue_size));
queue_size_ = queue_size; queue_size_ = queue_size;
queue_ = paddle::framework::MakeChannel<T>(); queue_ = paddle::framework::MakeChannel<T>();
queue_->SetCapacity(queue_size); queue_->SetCapacity(queue_size);
...@@ -418,8 +425,10 @@ void MultiSlotDataFeed::Init( ...@@ -418,8 +425,10 @@ void MultiSlotDataFeed::Init(
finish_set_filelist_ = false; finish_set_filelist_ = false;
finish_start_ = false; finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(), PADDLE_ENFORCE_EQ(
"Multi_slot_desc has not been set."); data_feed_desc.has_multi_slot_desc(), true,
platform::errors::PreconditionNotMet(
"Multi_slot_desc has not been set in MultiSlotDataFeed."));
paddle::framework::MultiSlotDesc multi_slot_desc = paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc(); data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size()); SetBatchSize(data_feed_desc.batch_size());
...@@ -668,13 +677,14 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) { ...@@ -668,13 +677,14 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
for (size_t i = 0; i < use_slots_index_.size(); ++i) { for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i]; int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10); int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE( PADDLE_ENFORCE_NE(
num, num, 0,
"The number of ids can not be zero, you need padding " platform::errors::InvalidArgument(
"it in data generator; or if there is something wrong with " "The number of ids can not be zero, you need padding "
"the data, please check if the data contains unresolvable " "it in data generator; or if there is something wrong with "
"characters.\nplease check this error line: %s", "the data, please check if the data contains unresolvable "
str); "characters.\nplease check this error line: %s.",
str));
if (idx != -1) { if (idx != -1) {
(*instance)[idx].Init(all_slots_type_[i]); (*instance)[idx].Init(all_slots_type_[i]);
...@@ -765,8 +775,10 @@ void MultiSlotInMemoryDataFeed::Init( ...@@ -765,8 +775,10 @@ void MultiSlotInMemoryDataFeed::Init(
finish_set_filelist_ = false; finish_set_filelist_ = false;
finish_start_ = false; finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(), PADDLE_ENFORCE_EQ(
"Multi_slot_desc has not been set."); data_feed_desc.has_multi_slot_desc(), true,
platform::errors::PreconditionNotMet(
"Multi_slot_desc has not been set in MultiSlotInMemoryDataFeed."));
paddle::framework::MultiSlotDesc multi_slot_desc = paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc(); data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size()); SetBatchSize(data_feed_desc.batch_size());
...@@ -898,13 +910,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { ...@@ -898,13 +910,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
for (size_t i = 0; i < use_slots_index_.size(); ++i) { for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i]; int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10); int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE( PADDLE_ENFORCE_NE(
num, num, 0,
"The number of ids can not be zero, you need padding " platform::errors::InvalidArgument(
"it in data generator; or if there is something wrong with " "The number of ids can not be zero, you need padding "
"the data, please check if the data contains unresolvable " "it in data generator; or if there is something wrong with "
"characters.\nplease check this error line: %s", "the data, please check if the data contains unresolvable "
str); "characters.\nplease check this error line: %s.",
str));
if (idx != -1) { if (idx != -1) {
if (all_slots_type_[i][0] == 'f') { // float if (all_slots_type_[i][0] == 'f') { // float
for (int j = 0; j < num; ++j) { for (int j = 0; j < num; ++j) {
...@@ -963,13 +976,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) { ...@@ -963,13 +976,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) {
for (size_t i = 0; i < use_slots_index_.size(); ++i) { for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i]; int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10); int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE( PADDLE_ENFORCE_NE(
num, num, 0,
"The number of ids can not be zero, you need padding " platform::errors::InvalidArgument(
"it in data generator; or if there is something wrong with " "The number of ids can not be zero, you need padding "
"the data, please check if the data contains unresolvable " "it in data generator; or if there is something wrong with "
"characters.\nplease check this error line: %s", "the data, please check if the data contains unresolvable "
str); "characters.\nplease check this error line: %s.",
str));
if (idx != -1) { if (idx != -1) {
if (all_slots_type_[i][0] == 'f') { // float if (all_slots_type_[i][0] == 'f') { // float
...@@ -1085,7 +1099,7 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec( ...@@ -1085,7 +1099,7 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
PADDLE_ENFORCE_EQ(slot_offset.size(), 2, PADDLE_ENFORCE_EQ(slot_offset.size(), 2,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"In batch reader, the sparse tensor lod size " "In batch reader, the sparse tensor lod size "
"must be 2, but received %d", "must be 2, but received %d.",
slot_offset.size())); slot_offset.size()));
const auto& max_size = slot_offset[1]; const auto& max_size = slot_offset[1];
tmp_offset.reserve(max_size + 1); tmp_offset.reserve(max_size + 1);
...@@ -1137,10 +1151,13 @@ void PrivateInstantDataFeed<T>::PutToFeedVec() { ...@@ -1137,10 +1151,13 @@ void PrivateInstantDataFeed<T>::PutToFeedVec() {
for (const auto e : use_slots_shape_[i]) { for (const auto e : use_slots_shape_[i]) {
total_dims *= e; total_dims *= e;
} }
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
total_dims == total_instance, total_dims, total_instance,
"The actual data size of slot[%s] doesn't match its declaration", platform::errors::InvalidArgument(
use_slots_[i].c_str()); "The actual data size of slot[%s] doesn't match its declaration. "
"The actual data size of slot is %lld"
", and its declaration is %lld.",
use_slots_[i].c_str(), total_dims, total_instance));
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i])); feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
} }
} }
...@@ -1162,7 +1179,9 @@ int PrivateInstantDataFeed<T>::Next() { ...@@ -1162,7 +1179,9 @@ int PrivateInstantDataFeed<T>::Next() {
return -1; return -1;
} }
PADDLE_ENFORCE(true == ParseOneMiniBatch(), "Fail to parse mini-batch data"); PADDLE_ENFORCE_EQ(
true, ParseOneMiniBatch(),
platform::errors::InvalidArgument("Fail to parse mini-batch data."));
PutToFeedVec(); PutToFeedVec();
return ins_vec_[0].GetBatchSize(); return ins_vec_[0].GetBatchSize();
} }
...@@ -1173,8 +1192,10 @@ void PrivateInstantDataFeed<T>::Init(const DataFeedDesc& data_feed_desc) { ...@@ -1173,8 +1192,10 @@ void PrivateInstantDataFeed<T>::Init(const DataFeedDesc& data_feed_desc) {
finish_set_filelist_ = false; finish_set_filelist_ = false;
finish_start_ = false; finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(), PADDLE_ENFORCE_EQ(
"Multi_slot_desc has not been set."); data_feed_desc.has_multi_slot_desc(), true,
platform::errors::PreconditionNotMet(
"Multi_slot_desc has not been set in PrivateInstantDataFeed."));
paddle::framework::MultiSlotDesc multi_slot_desc = paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc(); data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size()); SetBatchSize(data_feed_desc.batch_size());
...@@ -1217,7 +1238,10 @@ template class PrivateInstantDataFeed<std::vector<MultiSlotType>>; ...@@ -1217,7 +1238,10 @@ template class PrivateInstantDataFeed<std::vector<MultiSlotType>>;
bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) { bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) {
fd_ = open(filename.c_str(), O_RDONLY); fd_ = open(filename.c_str(), O_RDONLY);
PADDLE_ENFORCE(fd_ != -1, "Fail to open file: %s", filename.c_str()); PADDLE_ENFORCE_NE(
fd_, -1, platform::errors::Unavailable(
"Fail to open file: %s in MultiSlotFileInstantDataFeed.",
filename.c_str()));
struct stat sb; struct stat sb;
fstat(fd_, &sb); fstat(fd_, &sb);
...@@ -1225,7 +1249,11 @@ bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) { ...@@ -1225,7 +1249,11 @@ bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) {
buffer_ = buffer_ =
reinterpret_cast<char*>(mmap(NULL, end_, PROT_READ, MAP_PRIVATE, fd_, 0)); reinterpret_cast<char*>(mmap(NULL, end_, PROT_READ, MAP_PRIVATE, fd_, 0));
PADDLE_ENFORCE(buffer_ != MAP_FAILED, strerror(errno)); PADDLE_ENFORCE_NE(
buffer_, MAP_FAILED,
platform::errors::Unavailable(
"Memory map failed when create shared memory, error number is %s.",
strerror(errno)));
offset_ = 0; offset_ = 0;
return true; return true;
...@@ -1257,12 +1285,13 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() { ...@@ -1257,12 +1285,13 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {
char type = all_slots_type_[i][0]; char type = all_slots_type_[i][0];
uint16_t num = *reinterpret_cast<uint16_t*>(buffer_ + offset_); uint16_t num = *reinterpret_cast<uint16_t*>(buffer_ + offset_);
PADDLE_ENFORCE( PADDLE_ENFORCE_NE(
num, num, 0,
"The number of ids can not be zero, you need padding " platform::errors::InvalidArgument(
"it in data generator; or if there is something wrong with " "The number of ids can not be zero, you need padding "
"the data, please check if the data contains unresolvable " "it in data generator; or if there is something wrong with "
"characters."); "the data, please check if the data contains unresolvable "
"characters."));
offset_ += sizeof(uint16_t); offset_ += sizeof(uint16_t);
if (idx != -1) { if (idx != -1) {
...@@ -1304,7 +1333,12 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() { ...@@ -1304,7 +1333,12 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {
} }
PADDLE_ENFORCE(batch_size_ == default_batch_size_ || offset_ == end_, PADDLE_ENFORCE(batch_size_ == default_batch_size_ || offset_ == end_,
"offset_ != end_"); platform::errors::InvalidArgument(
"The batch size id not equal to default batch size, or "
"the offset is not equal to end index."
"The batch size is %d, default batcch size is %d, offset "
"is %d, end index is %d.",
batch_size_, default_batch_size_, offset_, end_));
return true; return true;
} }
#endif #endif
......
...@@ -116,7 +116,8 @@ class DataFeed { ...@@ -116,7 +116,8 @@ class DataFeed {
virtual ~DataFeed() {} virtual ~DataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc) = 0; virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
virtual bool CheckFile(const char* filename) { virtual bool CheckFile(const char* filename) {
PADDLE_THROW("This function(CheckFile) is not implemented."); PADDLE_THROW(platform::errors::Unimplemented(
"This function(CheckFile) is not implemented."));
} }
// Set filelist for DataFeed. // Set filelist for DataFeed.
// Pay attention that it must init all readers before call this function. // Pay attention that it must init all readers before call this function.
...@@ -179,7 +180,8 @@ class DataFeed { ...@@ -179,7 +180,8 @@ class DataFeed {
} }
virtual int GetCurBatchSize() { return batch_size_; } virtual int GetCurBatchSize() { return batch_size_; }
virtual void LoadIntoMemory() { virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented."); PADDLE_THROW(platform::errors::Unimplemented(
"This function(LoadIntoMemory) is not implemented."));
} }
virtual void SetPlace(const paddle::platform::Place& place) { virtual void SetPlace(const paddle::platform::Place& place) {
place_ = place; place_ = place;
...@@ -438,14 +440,23 @@ class MultiSlotType { ...@@ -438,14 +440,23 @@ class MultiSlotType {
private: private:
void CheckType(const std::string& type) const { void CheckType(const std::string& type) const {
PADDLE_ENFORCE((type == "uint64") || (type == "float"), PADDLE_ENFORCE_EQ((type == "uint64" || type == "float"), true,
"There is no this type<%s>.", type); platform::errors::InvalidArgument(
"MultiSlotType error, expect type is uint64 or "
"float, but received type is %s.",
type));
} }
void CheckFloat() const { void CheckFloat() const {
PADDLE_ENFORCE(type_[0] == 'f', "Add %s value to float slot.", type_); PADDLE_ENFORCE_EQ(
type_[0], 'f',
platform::errors::InvalidArgument(
"MultiSlotType error, add %s value to float slot.", type_));
} }
void CheckUint64() const { void CheckUint64() const {
PADDLE_ENFORCE(type_[0] == 'u', "Add %s value to uint64 slot.", type_); PADDLE_ENFORCE_EQ(
type_[0], 'u',
platform::errors::InvalidArgument(
"MultiSlotType error, add %s value to uint64 slot.", type_));
} }
std::vector<float> float_feasign_; std::vector<float> float_feasign_;
std::vector<uint64_t> uint64_feasign_; std::vector<uint64_t> uint64_feasign_;
......
...@@ -34,8 +34,10 @@ paddle::framework::DataFeedDesc load_datafeed_param_from_file( ...@@ -34,8 +34,10 @@ paddle::framework::DataFeedDesc load_datafeed_param_from_file(
const char* filename) { const char* filename) {
paddle::framework::DataFeedDesc data_feed_desc; paddle::framework::DataFeedDesc data_feed_desc;
int file_descriptor = open(filename, O_RDONLY); int file_descriptor = open(filename, O_RDONLY);
PADDLE_ENFORCE_NE(file_descriptor, -1, platform::errors::Unavailable( PADDLE_ENFORCE_NE(
"Cannot open file %s.", filename)); file_descriptor, -1,
platform::errors::Unavailable(
"Cannot open file %s c load datafeed param from file.", filename));
google::protobuf::io::FileInputStream fileInput(file_descriptor); google::protobuf::io::FileInputStream fileInput(file_descriptor);
google::protobuf::TextFormat::Parse(&fileInput, &data_feed_desc); google::protobuf::TextFormat::Parse(&fileInput, &data_feed_desc);
close(file_descriptor); close(file_descriptor);
...@@ -45,8 +47,10 @@ paddle::framework::DataFeedDesc load_datafeed_param_from_file( ...@@ -45,8 +47,10 @@ paddle::framework::DataFeedDesc load_datafeed_param_from_file(
const std::vector<std::string> load_filelist_from_file(const char* filename) { const std::vector<std::string> load_filelist_from_file(const char* filename) {
std::vector<std::string> filelist; std::vector<std::string> filelist;
std::ifstream fin(filename); std::ifstream fin(filename);
PADDLE_ENFORCE_EQ(fin.good(), true, platform::errors::Unavailable( PADDLE_ENFORCE_EQ(
"Cannot open file %s.", filename)); fin.good(), true,
platform::errors::Unavailable(
"Cannot open file %s when load filelist from file.", filename));
std::string line; std::string line;
while (getline(fin, line)) { while (getline(fin, line)) {
filelist.push_back(line); filelist.push_back(line);
...@@ -196,7 +200,8 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set, ...@@ -196,7 +200,8 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
} }
} }
} else { } else {
PADDLE_THROW("Error type in proto file."); PADDLE_THROW(platform::errors::InvalidArgument(
"Error type in proto file."));
} }
} else { // sparse branch } else { // sparse branch
if (slot.type() == "uint64") { if (slot.type() == "uint64") {
...@@ -218,7 +223,8 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set, ...@@ -218,7 +223,8 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
} }
} }
} else { } else {
PADDLE_THROW("Error type in proto file."); PADDLE_THROW(platform::errors::InvalidArgument(
"Error type in proto file."));
} }
} // end sparse branch } // end sparse branch
++index; ++index;
...@@ -272,7 +278,10 @@ void GetElemSetFromFile(std::vector<MultiTypeSet>* file_elem_set, ...@@ -272,7 +278,10 @@ void GetElemSetFromFile(std::vector<MultiTypeSet>* file_elem_set,
file_elem_set->resize(used_slot_num); file_elem_set->resize(used_slot_num);
for (const auto& file : filelist) { for (const auto& file : filelist) {
std::ifstream fin(file.c_str()); std::ifstream fin(file.c_str());
PADDLE_ENFORCE(fin.good(), "Can not open %s.", file.c_str()); PADDLE_ENFORCE_EQ(
fin.good(), true,
platform::errors::Unavailable(
"Can not open %s when get element set from file.", file.c_str()));
while (1) { while (1) {
bool end_flag = false; bool end_flag = false;
int index = 0; int index = 0;
...@@ -298,7 +307,8 @@ void GetElemSetFromFile(std::vector<MultiTypeSet>* file_elem_set, ...@@ -298,7 +307,8 @@ void GetElemSetFromFile(std::vector<MultiTypeSet>* file_elem_set,
} }
} }
} else { } else {
PADDLE_THROW("Error type in proto file."); PADDLE_THROW(
platform::errors::InvalidArgument("Error type in proto file."));
} }
if (slot.is_used()) { if (slot.is_used()) {
++index; ++index;
......
...@@ -45,7 +45,8 @@ inline DataLayout StringToDataLayout(const std::string& str) { ...@@ -45,7 +45,8 @@ inline DataLayout StringToDataLayout(const std::string& str) {
} else if (s == "MKLDNNLAYOUT") { } else if (s == "MKLDNNLAYOUT") {
return DataLayout::kMKLDNN; return DataLayout::kMKLDNN;
} else { } else {
PADDLE_THROW("Unknown storage order string: %s", s); PADDLE_THROW(platform::errors::InvalidArgument(
"Unknown data layout type string: %s.", s));
} }
} }
...@@ -60,7 +61,8 @@ inline std::string DataLayoutToString(const DataLayout& data_layout) { ...@@ -60,7 +61,8 @@ inline std::string DataLayoutToString(const DataLayout& data_layout) {
case DataLayout::kMKLDNN: case DataLayout::kMKLDNN:
return "MKLDNNLAYOUT"; return "MKLDNNLAYOUT";
default: default:
PADDLE_THROW("unknown DataLayout %d", data_layout); PADDLE_THROW(platform::errors::InvalidArgument(
"Unknown Data Layout type %d.", data_layout));
} }
} }
......
...@@ -25,14 +25,17 @@ namespace paddle { ...@@ -25,14 +25,17 @@ namespace paddle {
namespace framework { namespace framework {
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to) { std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to) {
PADDLE_ENFORCE_NE(from, to, PADDLE_ENFORCE_NE(
"layout transform should transform different layout"); from, to,
platform::errors::InvalidArgument(
"Layout transform should transform between different layout."));
if (from == DataLayout::kNCHW && to == DataLayout::kNHWC) { if (from == DataLayout::kNCHW && to == DataLayout::kNHWC) {
return {0, 2, 3, 1}; return {0, 2, 3, 1};
} else if (from == DataLayout::kNHWC && to == DataLayout::kNCHW) { } else if (from == DataLayout::kNHWC && to == DataLayout::kNCHW) {
return {0, 3, 1, 2}; return {0, 3, 1, 2};
} else { } else {
PADDLE_THROW("unsupported transform"); PADDLE_THROW(
platform::errors::InvalidArgument("Unsupported layout transform."));
} }
} }
...@@ -55,7 +58,8 @@ struct CastDataLayout { ...@@ -55,7 +58,8 @@ struct CastDataLayout {
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_); auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
trans4(*context, in_, out_, axis_); trans4(*context, in_, out_, axis_);
} else { } else {
PADDLE_THROW("Unsupport CPU <-> GPU!"); PADDLE_THROW(platform::errors::PreconditionNotMet(
"Unsupported data layout cast from CPU to GPU."));
} }
} }
}; };
...@@ -66,9 +70,14 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var, ...@@ -66,9 +70,14 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
PADDLE_ENFORCE( PADDLE_ENFORCE(
platform::places_are_same_class(kernel_type_for_var.place_, platform::places_are_same_class(kernel_type_for_var.place_,
expected_kernel_type.place_), expected_kernel_type.place_),
"TransDataLayout only support DataLayout transform on same place!"); platform::errors::PreconditionNotMet(
"TransDataLayout only support DataLayout transform on same place."));
PADDLE_ENFORCE(arity(in.dims()) == 4, "Input Arity only support 4!"); PADDLE_ENFORCE_EQ(
arity(in.dims()), 4,
platform::errors::InvalidArgument(
"Input dimension arity only can be 4, the input dimension is %s.",
in.dims()));
auto& pool = platform::DeviceContextPool::Instance(); auto& pool = platform::DeviceContextPool::Instance();
...@@ -108,7 +117,8 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) { ...@@ -108,7 +117,8 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
case mkldnn::memory::data_type::s32: case mkldnn::memory::data_type::s32:
return platform::to_void_cast(tensor.data<int32_t>()); return platform::to_void_cast(tensor.data<int32_t>());
default: default:
PADDLE_THROW("wrong mkldnn type provided"); PADDLE_THROW(
platform::errors::InvalidArgument("Wrong mkldnn type provided."));
} }
} }
...@@ -121,8 +131,9 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, ...@@ -121,8 +131,9 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
PADDLE_ENFORCE( PADDLE_ENFORCE(
in_layout == DataLayout::kMKLDNN && out_layout != DataLayout::kMKLDNN, in_layout == DataLayout::kMKLDNN && out_layout != DataLayout::kMKLDNN,
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to " platform::errors::InvalidArgument(
"non-MKLDNN"); "TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"non-MKLDNN"));
innerTransDataLayoutFromMKLDNN( innerTransDataLayoutFromMKLDNN(
in_layout, in_layout,
...@@ -155,7 +166,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, ...@@ -155,7 +166,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
memory::data_type in_type = ToMKLDNNDataType(in.type()); memory::data_type in_type = ToMKLDNNDataType(in.type());
PADDLE_ENFORCE_NE(in_type, memory::data_type::undef, PADDLE_ENFORCE_NE(in_type, memory::data_type::undef,
"Input tensor type is not supported: %s", in.type()); platform::errors::InvalidArgument(
"Input tensor type (%s) is not supported.",
DataTypeToString(in.type())));
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format()); auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
auto out_format = auto out_format =
......
...@@ -38,8 +38,9 @@ inline MKLDNNMemoryFormat ToMKLDNNFormat(const DataLayout& layout) { ...@@ -38,8 +38,9 @@ inline MKLDNNMemoryFormat ToMKLDNNFormat(const DataLayout& layout) {
case DataLayout::kNCHW: case DataLayout::kNCHW:
return MKLDNNMemoryFormat::nchw; return MKLDNNMemoryFormat::nchw;
default: default:
PADDLE_THROW("Fail to convert layout %s to MKLDNN format", PADDLE_THROW(platform::errors::InvalidArgument(
DataLayoutToString(layout)); "Fail to convert layout %s to MKLDNN format.",
DataLayoutToString(layout)));
} }
} }
...@@ -50,7 +51,8 @@ inline DataLayout ToPaddleLayout(const MKLDNNMemoryFormat& format) { ...@@ -50,7 +51,8 @@ inline DataLayout ToPaddleLayout(const MKLDNNMemoryFormat& format) {
case MKLDNNMemoryFormat::nchw: case MKLDNNMemoryFormat::nchw:
return DataLayout::kNCHW; return DataLayout::kNCHW;
default: default:
PADDLE_THROW("Fail to convert MKLDNN format to paddle layout"); PADDLE_THROW(platform::errors::InvalidArgument(
"Fail to convert MKLDNN format to paddle layout."));
} }
} }
......
...@@ -45,9 +45,10 @@ void TransformData(const OpKernelType &expected_kernel_type, ...@@ -45,9 +45,10 @@ void TransformData(const OpKernelType &expected_kernel_type,
if (NeedTransformLayout(lout, lin)) { if (NeedTransformLayout(lout, lin)) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (lin == DataLayout::kMKLDNN || lout == DataLayout::kMKLDNN) { if (lin == DataLayout::kMKLDNN || lout == DataLayout::kMKLDNN) {
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
!(lin == DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN), !(lin == DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN), true,
"No layout transform needed between two MKLDNN OPKernels"); platform::errors::PreconditionNotMet(
"No layout transform needed between two MKLDNN OPKernels."));
if (lin != DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN) { if (lin != DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN) {
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel // Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
...@@ -96,7 +97,10 @@ void TransformData(const OpKernelType &expected_kernel_type, ...@@ -96,7 +97,10 @@ void TransformData(const OpKernelType &expected_kernel_type,
PassTensorData(&out, &in); PassTensorData(&out, &in);
} }
PADDLE_ENFORCE(transformed, "No transform is applied, please check!"); PADDLE_ENFORCE_EQ(
transformed, true,
platform::errors::PreconditionNotMet(
"No transform is applied for the data needs to be transformed."));
// get output data // get output data
output_tensor->ShareDataWith(in); output_tensor->ShareDataWith(in);
} }
...@@ -116,7 +120,10 @@ void SetTensorToVariable(const Variable &in_var, const Tensor &tensor, ...@@ -116,7 +120,10 @@ void SetTensorToVariable(const Variable &in_var, const Tensor &tensor,
trans_selected_rows->set_rows(in_selected_rows.rows()); trans_selected_rows->set_rows(in_selected_rows.rows());
trans_selected_rows->mutable_value()->ShareDataWith(tensor); trans_selected_rows->mutable_value()->ShareDataWith(tensor);
} else { } else {
PADDLE_THROW("unknown var type"); PADDLE_THROW(platform::errors::Unavailable(
"Unsupported variable type, only supports LoDTensor or SelectedRows, "
"but the input variable type is %s.",
ToTypeName(in_var.Type())));
} }
} }
......
...@@ -65,7 +65,8 @@ proto::VarType::Type ToDataType(std::type_index type) { ...@@ -65,7 +65,8 @@ proto::VarType::Type ToDataType(std::type_index type) {
if (it != gDataTypeMap().cpp_to_proto_.end()) { if (it != gDataTypeMap().cpp_to_proto_.end()) {
return it->second; return it->second;
} }
PADDLE_THROW("Not support %s as tensor type", type.name()); PADDLE_THROW(platform::errors::Unimplemented(
"Not support %s as tensor data type.", platform::demangle(type.name())));
} }
std::type_index ToTypeIndex(proto::VarType::Type type) { std::type_index ToTypeIndex(proto::VarType::Type type) {
...@@ -73,8 +74,9 @@ std::type_index ToTypeIndex(proto::VarType::Type type) { ...@@ -73,8 +74,9 @@ std::type_index ToTypeIndex(proto::VarType::Type type) {
if (it != gDataTypeMap().proto_to_cpp_.end()) { if (it != gDataTypeMap().proto_to_cpp_.end()) {
return it->second; return it->second;
} }
PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type", PADDLE_THROW(platform::errors::Unimplemented(
static_cast<int>(type)); "Not support proto::VarType::Type(%d) as tensor type.",
static_cast<int>(type)));
} }
std::string DataTypeToString(const proto::VarType::Type type) { std::string DataTypeToString(const proto::VarType::Type type) {
...@@ -82,8 +84,9 @@ std::string DataTypeToString(const proto::VarType::Type type) { ...@@ -82,8 +84,9 @@ std::string DataTypeToString(const proto::VarType::Type type) {
if (it != gDataTypeMap().proto_to_str_.end()) { if (it != gDataTypeMap().proto_to_str_.end()) {
return it->second; return it->second;
} }
PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type", PADDLE_THROW(platform::errors::Unimplemented(
static_cast<int>(type)); "Not support proto::VarType::Type(%d) as tensor type.",
static_cast<int>(type)));
} }
size_t SizeOfType(proto::VarType::Type type) { size_t SizeOfType(proto::VarType::Type type) {
...@@ -91,7 +94,8 @@ size_t SizeOfType(proto::VarType::Type type) { ...@@ -91,7 +94,8 @@ size_t SizeOfType(proto::VarType::Type type) {
if (it != gDataTypeMap().proto_to_size_.end()) { if (it != gDataTypeMap().proto_to_size_.end()) {
return it->second; return it->second;
} }
PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type)); PADDLE_THROW(platform::errors::Unimplemented("Not support %s as tensor type.",
DataTypeToString(type)));
} }
} // namespace framework } // namespace framework
......
...@@ -78,7 +78,9 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { ...@@ -78,7 +78,9 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
_ForEachDataType_(VisitDataTypeCallback); _ForEachDataType_(VisitDataTypeCallback);
#undef VisitDataTypeCallback #undef VisitDataTypeCallback
PADDLE_THROW("Not supported %d", type); PADDLE_THROW(platform::errors::Unimplemented(
"Not supported proto::VarType::Type(%d) as data type.",
static_cast<int>(type)));
} }
template <typename Visitor> template <typename Visitor>
......
...@@ -56,7 +56,8 @@ struct CastDataType { ...@@ -56,7 +56,8 @@ struct CastDataType {
context->Wait(); context->Wait();
#endif #endif
} else { } else {
PADDLE_THROW("Unsupported place!"); PADDLE_THROW(platform::errors::Unimplemented(
"Place type is not supported when casting data type."));
} }
} }
}; };
...@@ -98,7 +99,9 @@ void TransDataType(const OpKernelType& kernel_type_for_var, ...@@ -98,7 +99,9 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx)); framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
break; break;
default: default:
PADDLE_THROW("Not support type %d", src_type); PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when casting data type.",
DataTypeToString(src_type)));
} }
} }
......
...@@ -81,9 +81,11 @@ bool contain_unknown_dim(const DDim& ddim) { ...@@ -81,9 +81,11 @@ bool contain_unknown_dim(const DDim& ddim) {
} }
DDim slice_ddim(const DDim& dim, int begin, int end) { DDim slice_ddim(const DDim& dim, int begin, int end) {
PADDLE_ENFORCE(begin >= 0 && end <= dim.size(), PADDLE_ENFORCE_EQ(
"[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.", (begin >= 0 && end <= dim.size()), true,
begin, end, dim.size()); platform::errors::InvalidArgument(
"[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.", begin,
end, dim.size()));
// Constructor of DDim would check whether end - begin is valid // Constructor of DDim would check whether end - begin is valid
return DDim(dim.Get() + begin, end - begin); return DDim(dim.Get() + begin, end - begin);
} }
......
...@@ -29,20 +29,23 @@ namespace framework { ...@@ -29,20 +29,23 @@ namespace framework {
return (callback); \ return (callback); \
} }
#define PADDLE_VISIT_DDIM(rank, callback) \ #define PADDLE_VISIT_DDIM(rank, callback) \
switch (rank) { \ switch (rank) { \
PADDLE_VISIT_DDIM_BASE(0, callback); \ PADDLE_VISIT_DDIM_BASE(0, callback); \
PADDLE_VISIT_DDIM_BASE(1, callback); \ PADDLE_VISIT_DDIM_BASE(1, callback); \
PADDLE_VISIT_DDIM_BASE(2, callback); \ PADDLE_VISIT_DDIM_BASE(2, callback); \
PADDLE_VISIT_DDIM_BASE(3, callback); \ PADDLE_VISIT_DDIM_BASE(3, callback); \
PADDLE_VISIT_DDIM_BASE(4, callback); \ PADDLE_VISIT_DDIM_BASE(4, callback); \
PADDLE_VISIT_DDIM_BASE(5, callback); \ PADDLE_VISIT_DDIM_BASE(5, callback); \
PADDLE_VISIT_DDIM_BASE(6, callback); \ PADDLE_VISIT_DDIM_BASE(6, callback); \
PADDLE_VISIT_DDIM_BASE(7, callback); \ PADDLE_VISIT_DDIM_BASE(7, callback); \
PADDLE_VISIT_DDIM_BASE(8, callback); \ PADDLE_VISIT_DDIM_BASE(8, callback); \
PADDLE_VISIT_DDIM_BASE(9, callback); \ PADDLE_VISIT_DDIM_BASE(9, callback); \
default: \ default: \
PADDLE_THROW("Invalid rank %d", rank); \ PADDLE_THROW(platform::errors::Unimplemented( \
"Invalid dimension to be accessed. Now only supports access to " \
"dimension 0 to 9, but received dimension is %d.", \
rank)); \
} }
template <typename T1, typename T2> template <typename T1, typename T2>
...@@ -92,13 +95,31 @@ class DDim { ...@@ -92,13 +95,31 @@ class DDim {
inline int64_t operator[](int idx) const { return dim_[idx]; } inline int64_t operator[](int idx) const { return dim_[idx]; }
inline int64_t& at(int idx) { int64_t& at(int idx) {
PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx); PADDLE_ENFORCE_GE(idx, 0,
platform::errors::InvalidArgument(
"Invalid DDim index to be accessed. The valid index "
"is between 0 and %d, but received index is %d.",
rank_, idx));
PADDLE_ENFORCE_LT(idx, rank_,
platform::errors::InvalidArgument(
"Invalid DDim index to be accessed. The valid index "
"is between 0 and %d, but received index is %d.",
rank_, idx));
return dim_[idx]; return dim_[idx];
} }
inline int64_t at(int idx) const { int64_t at(int idx) const {
PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx); PADDLE_ENFORCE_GE(idx, 0,
platform::errors::InvalidArgument(
"Invalid DDim index to be accessed. The valid index "
"is between 0 and %d, but received index is %d.",
rank_, idx));
PADDLE_ENFORCE_LT(idx, rank_,
platform::errors::InvalidArgument(
"Invalid DDim index to be accessed. The valid index "
"is between 0 and %d, but received index is %d.",
rank_, idx));
return dim_[idx]; return dim_[idx];
} }
......
...@@ -42,53 +42,18 @@ inline void InitVarsInScope(const std::vector<VarInfo> &var_infos, Scope *scope, ...@@ -42,53 +42,18 @@ inline void InitVarsInScope(const std::vector<VarInfo> &var_infos, Scope *scope,
} }
} }
// get RpcContext and remote send and recv op // get CommContext and remote send and recv op
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
using RpcCtxMap = operators::distributed::RpcCtxMap;
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx;
for (auto &node : graphs[0]->Nodes()) {
VLOG(3) << "node name " << node->Name();
if (node && node->IsOp()) {
if (node->Name() == "send") {
auto send_var_name = node->Op()->Input("X")[0];
auto send_varnames =
BOOST_GET_CONST(std::vector<std::string>,
node->Op()->GetNullableAttr("send_varnames"));
auto epmap = BOOST_GET_CONST(std::vector<std::string>,
node->Op()->GetNullableAttr("epmap"));
auto height_section = BOOST_GET_CONST(
std::vector<int64_t>, node->Op()->GetNullableAttr("sections"));
auto trainer_id =
BOOST_GET_CONST(int, node->Op()->GetNullableAttr("trainer_id"));
auto merge_add =
BOOST_GET_CONST(bool, node->Op()->GetNullableAttr("merge_add"));
if (!merge_add) {
merge_add = FLAGS_communicator_is_sgd_optimizer;
}
auto use_send_handler = BOOST_GET_CONST(
bool, node->Op()->GetNullableAttr("use_send_handler"));
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
send_var_name, send_varnames, epmap, height_section, trainer_id,
merge_add, use_send_handler);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
}
}
}
// init communicator here // init communicator here
if (send_varname_to_ctx.size() > 0) { auto *instance = operators::distributed::Communicator::GetInstance();
auto *instance = operators::distributed::Communicator::GetInstance(); auto initialized = instance ? true : false;
auto initialized = instance ? true : false; PADDLE_ENFORCE_EQ(initialized, true,
PADDLE_ENFORCE_EQ(initialized, true, platform::errors::InvalidArgument(
platform::errors::InvalidArgument( "Communicator is not Initialized, you may use "
"Communicator is not Initialized, you may use " "FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
"FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/" "develop/markdown_doc/transpiler)"));
"develop/markdown_doc/transpiler)"));
}
#endif #endif
} }
......
...@@ -111,6 +111,7 @@ void DeviceWorker::DumpParam(const Scope& scope, const int batch_id) { ...@@ -111,6 +111,7 @@ void DeviceWorker::DumpParam(const Scope& scope, const int batch_id) {
writer_ << os.str(); writer_ << os.str();
} }
} }
void DeviceWorker::InitRandomDumpConfig(const TrainerDesc& desc) { void DeviceWorker::InitRandomDumpConfig(const TrainerDesc& desc) {
bool enable_random_dump = desc.enable_random_dump(); bool enable_random_dump = desc.enable_random_dump();
if (!enable_random_dump) { if (!enable_random_dump) {
......
...@@ -335,6 +335,7 @@ class SectionWorker : public DeviceWorker { ...@@ -335,6 +335,7 @@ class SectionWorker : public DeviceWorker {
void SetSkipVars(const std::vector<std::string>& skip_vars) { void SetSkipVars(const std::vector<std::string>& skip_vars) {
skip_vars_ = skip_vars; skip_vars_ = skip_vars;
} }
static void ResetBatchId() { batch_id_ = 0; }
static std::atomic<int> cpu_id_; static std::atomic<int> cpu_id_;
......
...@@ -99,7 +99,7 @@ void DistMultiTrainer::InitTrainerEnv(const ProgramDesc &main_program, ...@@ -99,7 +99,7 @@ void DistMultiTrainer::InitTrainerEnv(const ProgramDesc &main_program,
} }
void DistMultiTrainer::InitOtherEnv(const ProgramDesc &main_program) { void DistMultiTrainer::InitOtherEnv(const ProgramDesc &main_program) {
if (need_dump_field_) { if (need_dump_field_ || need_dump_param_) {
InitDumpEnv(); InitDumpEnv();
} }
pull_dense_worker_->SetRootScope(root_scope_); pull_dense_worker_->SetRootScope(root_scope_);
...@@ -158,7 +158,7 @@ void DistMultiTrainer::Finalize() { ...@@ -158,7 +158,7 @@ void DistMultiTrainer::Finalize() {
} }
} }
if (need_dump_field_) { if (need_dump_field_ || need_dump_param_) {
FinalizeDumpEnv(); FinalizeDumpEnv();
} }
pull_dense_worker_->Stop(); pull_dense_worker_->Stop();
......
...@@ -49,7 +49,12 @@ TEST(DisMultiTrainerTest, test1) { ...@@ -49,7 +49,12 @@ TEST(DisMultiTrainerTest, test1) {
dataset->SetTrainerNum(1); dataset->SetTrainerNum(1);
dataset->SetDataFeedDesc(str); dataset->SetDataFeedDesc(str);
dataset->CreateReaders(); dataset->CreateReaders();
Scope root_scope;
tmp1->SetScope(&root_scope);
tmp1->Initialize(t, dataset.get()); tmp1->Initialize(t, dataset.get());
ProgramDesc p;
tmp1->InitOtherEnv(p);
tmp1->Finalize();
#endif #endif
} }
} // namespace framework } // namespace framework
......
...@@ -22,56 +22,104 @@ enum Mode { ...@@ -22,56 +22,104 @@ enum Mode {
HETER = 4; // support XPU and GPU computing server HETER = 4; // support XPU and GPU computing server
} }
message DistributedStrategy { message RecomputeConfig { repeated string checkpoints = 1; }
optional Mode mode = 1 [ default = COLLECTIVE ]; // just for serialization
// collective training strategy message AMPConfig {
optional bool amp = 2 [ default = false ]; optional float init_loss_scaling = 1 [ default = 32768.0 ];
optional int32 amp_loss_scaling = 3 [ default = 32768 ]; optional int32 incr_every_n_steps = 2 [ default = 1000 ];
optional bool recompute = 4 [ default = false ]; optional int32 decr_every_n_nan_or_inf = 3 [ default = 2 ];
repeated string recompute_checkpoints = 5; optional float incr_ratio = 4 [ default = 2.0 ];
optional bool localsgd = 6 [ default = false ]; optional float decr_ratio = 5 [ default = 0.8 ];
optional int32 localsgd_k_step = 7 [ default = 4 ]; optional bool use_dynamic_loss_scaling = 6 [ default = true ];
optional bool dgc = 8 [ default = false ]; repeated string custom_white_list = 7;
optional bool hierachical_allreduce = 9 [ default = false ]; repeated string custom_black_list = 8;
optional int32 nccl_comm_num = 10 [ default = 1 ]; repeated string custom_black_varnames = 9;
optional bool gradient_merge = 11 [ default = false ]; }
optional int32 gradient_merge_k_step = 12 [ default = 1 ];
optional bool sequential_execution = 13 [ default = false ]; message LocalSGDConfig { optional int32 k_steps = 1 [ default = 4 ]; }
optional bool enable_backward_optimizer_op_deps = 14 [ default = true ];
optional bool lars = 15 [ default = false ]; message GradientMergeConfig {
optional bool lamb = 16 [ default = false ]; optional int32 k_steps = 1 [ default = 1 ];
optional bool fuse_elewise_add_act_ops = 17 [ default = false ]; optional bool avg = 2 [ default = true ];
optional bool fuse_bn_act_ops = 18 [ default = false ]; }
optional bool enable_auto_fusion = 19 [ default = false ];
optional bool fuse_relu_depthwise_conv = 20 [ default = false ]; message LarsConfig {
optional bool enable_inplace = 21 [ default = false ]; optional float lars_coeff = 1 [ default = 0.001 ];
optional bool fuse_all_reduce_ops = 22 [ default = false ]; optional float lars_weight_decay = 2 [ default = 0.0005 ];
optional int32 num_iteration_per_drop_scope = 23 [ default = 1 ]; }
optional bool sync_batch_norm = 24 [ default = false ];
optional bool fuse_all_optimizer_ops = 25 [ default = false ];
// pipeline training message LambConfig {
optional bool pipeline = 101 [ default = false ]; optional float beta1 = 1 [ default = 0.001 ];
optional int32 pipeline_micro_batch = 102; optional float beta2 = 2 [ default = 0.999 ];
optional float epsilon = 3 [ default = 0.000001 ];
}
// parameter server training message BuildStrategy {
optional bool sync = 201 [ default = false ]; optional bool enable_sequential_execution = 1 [ default = false ];
optional bool async = 202 [ default = true ]; optional bool fuse_elewise_add_act_ops = 2 [ default = false ];
optional int32 async_k_step = 203 [ default = -1 ]; optional bool fuse_bn_act_ops = 3 [ default = false ];
optional int32 max_merge_var_num = 204 [ default = 1 ]; optional bool fuse_relu_depthwise_conv = 4 [ default = false ];
optional int32 send_queue_size = 205 [ default = 16 ]; optional bool fuse_broadcast_ops = 5 [ default = false ];
optional bool independent_recv_thread = 206 [ default = false ]; optional bool fuse_all_optimizer_ops = 6 [ default = false ];
optional int32 min_send_grad_num_before_recv = 207 [ default = 1 ]; optional bool enable_inplace = 7 [ default = false ];
optional int32 thread_pool_size = 208 [ default = 1 ]; optional bool enable_backward_optimizer_op_deps = 8 [ default = true ];
optional int32 send_wait_times = 209 [ default = 1 ]; optional bool cache_runtime_context = 9 [ default = false ];
optional bool runtime_split_send_recv = 210 [ default = false ]; }
optional bool use_thread_barrier = 211 [ default = false ];
// elastic deep learning strategies message ExecutionStrategy {
optional bool elastic = 301 [ default = false ]; optional int32 num_threads = 1 [ default = 1 ];
optional int32 num_iteration_per_drop_scope = 2 [ default = 10 ];
optional int32 num_iteration_per_run = 3 [ default = 1 ];
optional bool use_thread_barrier = 4 [ default = false ];
}
message AsyncConfig {
optional int32 k_steps = 1 [ default = 1 ];
optional int32 max_merge_var_num = 2 [ default = 1 ];
optional int32 send_queue_size = 3 [ default = 16 ];
optional bool independent_recv_thread = 4 [ default = false ];
optional int32 min_send_grad_num_before_recv = 5 [ default = 1 ];
optional int32 thread_pool_size = 6 [ default = 1 ];
optional int32 send_wait_times = 7 [ default = 1 ];
optional bool runtime_split_send_recv = 8 [ default = false ];
}
message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
message DistributedStrategy {
// bool options
optional Mode mode = 1 [ default = COLLECTIVE ];
optional bool amp = 2 [ default = false ];
optional bool recompute = 3 [ default = false ];
optional bool localsgd = 4 [ default = false ];
optional bool dgc = 5 [ default = false ];
optional bool gradient_merge = 6 [ default = false ];
optional bool lars = 7 [ default = false ];
optional bool lamb = 8 [ default = false ];
optional bool pipeline = 9 [ default = false ];
optional bool elastic = 10 [ default = false ];
optional bool auto = 11 [ default = false ];
optional bool a_sync = 12 [ default = true ];
optional bool sync_nccl_allreduce = 13 [ default = true ];
optional int32 nccl_comm_num = 14 [ default = 1 ];
optional bool use_hierarchical_allreduce = 15 [ default = false ];
optional int32 hierarchical_allreduce_inter_nranks = 16 [ default = 1 ];
optional bool sync_batch_norm = 17 [ default = false ];
optional bool fuse_all_reduce_ops = 18 [ default = true ];
optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ];
optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ];
// optional bool enable_backward_optimizer_op_deps = 19 [ default = true ];
// auto parallel optional RecomputeConfig recompute_configs = 101;
optional bool auto = 401 [ default = false ]; optional AMPConfig amp_configs = 102;
optional LocalSGDConfig localsgd_configs = 103;
optional GradientMergeConfig gradient_merge_configs = 104;
optional PipelineConfig pipeline_configs = 106;
optional AsyncConfig a_sync_configs = 107;
optional LarsConfig lars_configs = 108;
optional LambConfig lamb_configs = 109;
optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202;
} }
message DistributedJobInfo { message DistributedJobInfo {
......
...@@ -30,7 +30,10 @@ static ::DLDataType GetDLDataTypeCode() { ...@@ -30,7 +30,10 @@ static ::DLDataType GetDLDataTypeCode() {
} else if (std::is_integral<T>::value) { } else if (std::is_integral<T>::value) {
dtype.code = kDLInt; dtype.code = kDLInt;
} else { } else {
PADDLE_THROW("Unsupported data type %s", typeid(T).name()); PADDLE_THROW(platform::errors::Unavailable(
"Unsupported data type (%s), only supports float16, float, unsigned "
"int and int.",
platform::demangle(typeid(T).name())));
} }
dtype.bits = 8 * sizeof(T); dtype.bits = 8 * sizeof(T);
dtype.lanes = 1; dtype.lanes = 1;
...@@ -52,8 +55,9 @@ static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) { ...@@ -52,8 +55,9 @@ static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
static auto type_to_dtype_map = CreateDLDataTypeMap(); static auto type_to_dtype_map = CreateDLDataTypeMap();
static auto type_to_dtype_map_end_it = type_to_dtype_map.end(); static auto type_to_dtype_map_end_it = type_to_dtype_map.end();
auto it = type_to_dtype_map.find(static_cast<int>(type)); auto it = type_to_dtype_map.find(static_cast<int>(type));
PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %d", PADDLE_ENFORCE_NE(it, type_to_dtype_map_end_it,
type); platform::errors::InvalidArgument(
"Unsupported data type (%s).", DataTypeToString(type)));
return it->second; return it->second;
#undef REG_DL_DATA_TYPE #undef REG_DL_DATA_TYPE
} }
...@@ -73,7 +77,8 @@ struct DLContextVisitor : public boost::static_visitor<::DLContext> { ...@@ -73,7 +77,8 @@ struct DLContextVisitor : public boost::static_visitor<::DLContext> {
ctx.device_id = place.device; ctx.device_id = place.device;
return ctx; return ctx;
#else #else
PADDLE_THROW("platform::CUDAPlace is not supported in CPU only version"); PADDLE_THROW(platform::errors::Unavailable(
"platform::CUDAPlace is not supported in CPU only version."));
#endif #endif
} }
...@@ -84,8 +89,8 @@ struct DLContextVisitor : public boost::static_visitor<::DLContext> { ...@@ -84,8 +89,8 @@ struct DLContextVisitor : public boost::static_visitor<::DLContext> {
ctx.device_id = 0; ctx.device_id = 0;
return ctx; return ctx;
#else #else
PADDLE_THROW( PADDLE_THROW(platform::errors::Unavailable(
"platform::CUDAPinnedPlace is not supported in CPU only version"); "platform::CUDAPinnedPlace is not supported in CPU only version."));
#endif #endif
} }
}; };
...@@ -136,7 +141,10 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) { ...@@ -136,7 +141,10 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
// refer to cupy and cudf, the compact tensor first dim's strides need to be 1 // refer to cupy and cudf, the compact tensor first dim's strides need to be 1
// and second dim's strides need to be length of rows of cudf // and second dim's strides need to be length of rows of cudf
// cudf now only support dim=2 // cudf now only support dim=2
PADDLE_ENFORCE_LE(t_.ndim, 2, "cudf now only support dim=2."); PADDLE_ENFORCE_LE(t_.ndim, 2, platform::errors::InvalidArgument(
"cudf now only supports dimension is 2, "
"but received dimension is %d.",
t_.ndim));
if (t_.ndim > 1) if (t_.ndim > 1)
t_.strides = new int64_t[2]{1, t_.shape[1]}; t_.strides = new int64_t[2]{1, t_.shape[1]};
......
...@@ -556,9 +556,11 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -556,9 +556,11 @@ void DownpourWorker::TrainFilesWithProfiler() {
continue; continue;
} }
PADDLE_ENFORCE_EQ(framework::TensorContainsInf(*tensor), false, PADDLE_ENFORCE_EQ(framework::TensorContainsInf(*tensor), false,
"Tensor %s contains Inf", var_name); platform::errors::InvalidArgument(
"Tensor %s contains Inf.", var_name));
PADDLE_ENFORCE_EQ(framework::TensorContainsNAN(*tensor), false, PADDLE_ENFORCE_EQ(framework::TensorContainsNAN(*tensor), false,
"Tensor %s contains NAN", var_name); platform::errors::InvalidArgument(
"Tensor %s contains NAN.", var_name));
} }
if (need_to_push_sparse_) { if (need_to_push_sparse_) {
...@@ -829,9 +831,11 @@ void DownpourWorker::TrainFiles() { ...@@ -829,9 +831,11 @@ void DownpourWorker::TrainFiles() {
continue; continue;
} }
PADDLE_ENFORCE_EQ(framework::TensorContainsInf(*tensor), false, PADDLE_ENFORCE_EQ(framework::TensorContainsInf(*tensor), false,
"Tensor %s contains Inf", var_name); platform::errors::InvalidArgument(
"Tensor %s contains Inf.", var_name));
PADDLE_ENFORCE_EQ(framework::TensorContainsNAN(*tensor), false, PADDLE_ENFORCE_EQ(framework::TensorContainsNAN(*tensor), false,
"Tensor %s contains NAN", var_name); platform::errors::InvalidArgument(
"Tensor %s contains NAN.", var_name));
} }
if (need_to_push_sparse_) { if (need_to_push_sparse_) {
......
...@@ -26,7 +26,11 @@ struct EigenDim { ...@@ -26,7 +26,11 @@ struct EigenDim {
using Type = Eigen::DSizes<Eigen::DenseIndex, D>; using Type = Eigen::DSizes<Eigen::DenseIndex, D>;
static Type From(const DDim& dims) { static Type From(const DDim& dims) {
PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); PADDLE_ENFORCE_EQ(arity(dims), D,
platform::errors::InvalidArgument(
"Input dimension size should be equal to %d, but "
"received dimension size is %d.",
arity(dims), D));
Type ret; Type ret;
for (int64_t d = 0; d < arity(dims); d++) { for (int64_t d = 0; d < arity(dims); d++) {
ret[d] = dims[d]; ret[d] = dims[d];
...@@ -69,8 +73,11 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> { ...@@ -69,8 +73,11 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
static typename EigenMatrix::Type Reshape(Tensor& tensor, // NOLINT static typename EigenMatrix::Type Reshape(Tensor& tensor, // NOLINT
int num_col_dims) { int num_col_dims) {
int rank = tensor.dims_.size(); int rank = tensor.dims_.size();
PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank, PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true,
"`num_col_dims` must be between (0, rank_of_tensor)."); platform::errors::InvalidArgument(
"Input dimension number(num_col_dims) must be "
"between 0 and %d, but received number is %d.",
rank, num_col_dims));
return EigenMatrix::From(tensor, return EigenMatrix::From(tensor,
flatten_to_2d(tensor.dims(), num_col_dims)); flatten_to_2d(tensor.dims(), num_col_dims));
} }
...@@ -78,8 +85,11 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> { ...@@ -78,8 +85,11 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
static typename EigenMatrix::ConstType Reshape(const Tensor& tensor, static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
int num_col_dims) { int num_col_dims) {
int rank = tensor.dims_.size(); int rank = tensor.dims_.size();
PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank, PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true,
"`num_col_dims` must be between (0, rank_of_tensor)."); platform::errors::InvalidArgument(
"Input dimension number(num_col_dims) must be "
"between 0 and %d, but received number is %d.",
rank, num_col_dims));
return EigenMatrix::From(tensor, return EigenMatrix::From(tensor,
flatten_to_2d(tensor.dims(), num_col_dims)); flatten_to_2d(tensor.dims(), num_col_dims));
} }
......
...@@ -37,9 +37,12 @@ limitations under the License. */ ...@@ -37,9 +37,12 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run"); DECLARE_bool(use_mkldnn);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -83,14 +86,7 @@ Executor::~Executor() { ...@@ -83,14 +86,7 @@ Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache, // Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working // this is needed to have mkl-dnn unit tests working
if (platform::is_cpu_place(place_)) { ClearMKLDNNCache(place_);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place_);
dev_ctx->ResetBlobMap();
platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
paddle::framework::DataLayout::kNCHW);
}
#endif #endif
} }
......
...@@ -175,8 +175,9 @@ void DeleteUnusedTensors( ...@@ -175,8 +175,9 @@ void DeleteUnusedTensors(
garbages.emplace_back(t.MoveMemoryHolder()); garbages.emplace_back(t.MoveMemoryHolder());
} }
} else { } else {
PADDLE_THROW("Type %s of %s is not supported eager deletion", PADDLE_THROW(platform::errors::Unimplemented(
framework::ToTypeName(var->Type()), var_name); "Type %s of variable %s is not supported eager deletion.",
framework::ToTypeName(var->Type()), var_name));
} }
} }
......
...@@ -79,15 +79,15 @@ StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place, ...@@ -79,15 +79,15 @@ StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size) size_t max_memory_size)
: GarbageCollector(place, max_memory_size) { : GarbageCollector(place, max_memory_size) {
platform::CUDADeviceGuard guard(place.device); platform::CUDADeviceGuard guard(place.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_)); PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream_));
callback_manager_.reset(new platform::StreamCallbackManager(stream_)); callback_manager_.reset(new platform::StreamCallbackManager(stream_));
} }
StreamGarbageCollector::~StreamGarbageCollector() { StreamGarbageCollector::~StreamGarbageCollector() {
auto place = BOOST_GET_CONST(platform::CUDAPlace, this->dev_ctx_->GetPlace()); auto place = BOOST_GET_CONST(platform::CUDAPlace, this->dev_ctx_->GetPlace());
platform::CUDADeviceGuard guard(place.device); platform::CUDADeviceGuard guard(place.device);
PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaStreamDestroy(stream_)); PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_));
} }
cudaStream_t StreamGarbageCollector::stream() const { return stream_; } cudaStream_t StreamGarbageCollector::stream() const { return stream_; }
......
...@@ -96,14 +96,14 @@ class GradOpDescMakerBase { ...@@ -96,14 +96,14 @@ class GradOpDescMakerBase {
if (!drop_empty_grad) { if (!drop_empty_grad) {
return ret_val; return ret_val;
} }
PADDLE_ENFORCE_LE(var_names.size(), 1UL, PADDLE_ENFORCE_LE(
"BUG from operator developer:" var_names.size(), 1UL,
" for input argument with a list of variables, " platform::errors::Unavailable(
" drop_empty_grad is not allowed because it makes" "BUG from operator developer:"
" the correspondence bewteen a variable and its gradient" " for input argument with a list of variables, "
" ambiguous." " drop_empty_grad is not allowed because it makes"
" Op type %s", " the correspondence bewteen a variable and its gradient"
fwd_op_.Type()); " ambiguous."));
std::vector<std::string> dropped_ret_val; std::vector<std::string> dropped_ret_val;
dropped_ret_val.reserve(ret_val.size()); dropped_ret_val.reserve(ret_val.size());
...@@ -157,7 +157,8 @@ class GradOpDescMakerBase { ...@@ -157,7 +157,8 @@ class GradOpDescMakerBase {
const Attribute& GetAttr(const std::string& name) const { const Attribute& GetAttr(const std::string& name) const {
auto& map = fwd_op_.GetAttrMap(); auto& map = fwd_op_.GetAttrMap();
auto it = map.find(name); auto it = map.find(name);
PADDLE_ENFORCE(it != map.end(), "Cannot find attribute %s", name); PADDLE_ENFORCE_NE(it, map.end(), platform::errors::NotFound(
"Cannot find attribute (%s).", name));
return it->second; return it->second;
} }
......
...@@ -53,7 +53,9 @@ void HogwildWorker::CreateThreadScope(const ProgramDesc &program) { ...@@ -53,7 +53,9 @@ void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
auto &block = program.Block(0); auto &block = program.Block(0);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
root_scope_, "root_scope should be set before creating thread scope"); root_scope_,
platform::errors::NotFound(
"Root scope should be set before creating thread scope."));
thread_scope_ = &root_scope_->NewScope(); thread_scope_ = &root_scope_->NewScope();
......
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include <fcntl.h> #include <fcntl.h>
#include <sys/stat.h> #include <sys/stat.h>
#ifdef _WIN32 #ifdef _WIN32
#ifndef NOMINMAX
#define NOMINMAX // msvc max/min macro conflict with std::min/max
#endif
#include <windows.h> #include <windows.h>
#else #else
#include <sys/syscall.h> #include <sys/syscall.h>
......
...@@ -4,7 +4,7 @@ file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeList ...@@ -4,7 +4,7 @@ file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeList
file(APPEND ${pass_file} "\#pragma once\n") file(APPEND ${pass_file} "\#pragma once\n")
file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n") file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
copy_if_different(${pass_file} ${pass_file_final} extern_glog) copy_if_different(${pass_file} ${pass_file_final})
add_subdirectory(fuse_optimizer_ops_pass) add_subdirectory(fuse_optimizer_ops_pass)
add_subdirectory(memory_optimize_pass) add_subdirectory(memory_optimize_pass)
......
...@@ -26,15 +26,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -26,15 +26,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::FC fc_pattern(pattern, name_scope);
patterns::GRU gru_pattern(pattern, name_scope);
PDNode* x = PDNode* x =
pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable(); pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable();
// Create pattern.
patterns::FC fc_pattern(pattern, name_scope);
auto* fc_out = fc_pattern(x, with_fc_bias, /* with_relu */ false); auto* fc_out = fc_pattern(x, with_fc_bias, /* with_relu */ false);
fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse. fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse.
patterns::GRU gru_pattern(pattern, name_scope);
gru_pattern(fc_out); gru_pattern(fc_out);
// Create New OpDesc // Create New OpDesc
...@@ -48,17 +48,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -48,17 +48,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
SET_IN(X, x); SET_IN(X, x);
SET_IN(WeightX, weight_x); SET_IN(WeightX, weight_x);
SET_IN(WeightH, weight_h); SET_IN(WeightH, weight_h);
if (with_fc_bias) { SET_IN(Bias, bias);
op_desc.SetInput("Bias", {NEW_NAME(bias) + bias->Name()});
} else {
SET_IN(Bias, bias);
}
#undef SET_IN #undef SET_IN
// TODO(grygielski): Add H0 to the pass
op_desc.SetInput("H0", {}); op_desc.SetInput("H0", {});
op_desc.SetOutput("Hidden", {hidden->Name()}); op_desc.SetOutput("Hidden", {hidden->Name()});
op_desc.SetAttr("is_reverse", gru->Op()->GetAttr("is_reverse")); op_desc.SetAttr("is_reverse", gru->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("origin_mode",
gru->Op()->GetAttrIfExists<bool>("origin_mode"));
// TODO(TJ): This should be a option for infer // TODO(TJ): This should be a option for infer
op_desc.SetAttr("use_seq", true); op_desc.SetAttr("use_seq", true);
op_desc.SetAttr("activation", gru->Op()->GetAttr("activation"));
op_desc.SetAttr("gate_activation", gru->Op()->GetAttr("gate_activation"));
#define SET_IMTERMEDIATE_OUT(key) op_desc.SetOutput(#key, {NEW_NAME(key)}) #define SET_IMTERMEDIATE_OUT(key) op_desc.SetOutput(#key, {NEW_NAME(key)})
SET_IMTERMEDIATE_OUT(ReorderedH0); SET_IMTERMEDIATE_OUT(ReorderedH0);
...@@ -68,35 +69,30 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -68,35 +69,30 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef SET_IMTERMEDIATE_OUT #undef SET_IMTERMEDIATE_OUT
auto* op = graph->CreateOpNode(&op_desc); auto* op = graph->CreateOpNode(&op_desc);
PADDLE_ENFORCE_EQ(graph->Has(kParamScopeAttr), true,
platform::errors::InvalidArgument(
"Graph have no attr kParamScopeAttr."));
auto& scope = graph->Get<Scope>(kParamScopeAttr);
if (with_fc_bias) { if (with_fc_bias) {
// Fusion GRU bias = fcbias + grubias auto* gru_bias_var = scope->FindVar(bias->Name());
auto* fusion_bias_var = scope.Var(NEW_NAME(bias) + bias->Name()); auto* fc_bias_var = scope->FindVar(fc_bias->Name());
auto* out_bias_tensor = PADDLE_ENFORCE_NE(
fusion_bias_var->GetMutable<framework::LoDTensor>(); gru_bias_var, nullptr,
PADDLE_ENFORCE_NOT_NULL( platform::errors::NotFound("GRU bias var has not been found."));
fusion_bias_var, PADDLE_ENFORCE_NE(
platform::errors::InvalidArgument( fc_bias_var, nullptr,
"Fusion bias variable's pointer cannot be nullptr.")); platform::errors::NotFound("FC bias var has not been found."));
auto* gru_bias_var = scope.FindVar(bias->Name());
auto* fc_bias_var = scope.FindVar(fc_bias->Name()); auto* gru_bias_tensor = gru_bias_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_NOT_NULL(gru_bias_var, auto* fc_bias_tensor = fc_bias_var->GetMutable<LoDTensor>();
platform::errors::InvalidArgument( PADDLE_ENFORCE_EQ(
"Gru bias var ptr cannot be nullptr.")); gru_bias_tensor->numel(), fc_bias_tensor->numel(),
PADDLE_ENFORCE_NOT_NULL(fc_bias_var, platform::errors::PreconditionNotMet(
platform::errors::InvalidArgument( "GRU and FC biases have to have equal number of elements."));
"Fc bias var ptr cannot be nullptr."));
const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>(); auto gru_bias_data =
const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>(); gru_bias_tensor->mutable_data<float>(platform::CPUPlace());
// new bias = fc bias + gru bias auto* fc_bias_data = fc_bias_tensor->data<float>();
out_bias_tensor->Resize(gru_bias_tenosr.dims());
auto* data = out_bias_tensor->mutable_data<float>(platform::CPUPlace()); // Recompute GRU bias
for (int i = 0; i < out_bias_tensor->numel(); i++) { for (int i = 0; i < gru_bias_tensor->numel(); ++i) {
data[i] = gru_bias_data[i] += fc_bias_data[i];
fc_bias_tensor.data<float>()[i] + gru_bias_tenosr.data<float>()[i];
} }
} }
#undef GET_NODE #undef GET_NODE
...@@ -117,7 +113,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -117,7 +113,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
IR_NODE_LINK_TO(x, op); IR_NODE_LINK_TO(x, op);
IR_NODE_LINK_TO(weight_x, op); IR_NODE_LINK_TO(weight_x, op);
IR_NODE_LINK_TO(weight_h, op); IR_NODE_LINK_TO(weight_h, op);
IR_NODE_LINK_TO(bias, op); // actually should link to new bias if have IR_NODE_LINK_TO(bias, op);
IR_NODE_LINK_TO(op, hidden); IR_NODE_LINK_TO(op, hidden);
// h0? // h0?
return op; return op;
......
...@@ -320,7 +320,7 @@ std::vector<Node *> FuseBatchNormActPass::ReplaceNode( ...@@ -320,7 +320,7 @@ std::vector<Node *> FuseBatchNormActPass::ReplaceNode(
return node; return node;
}); });
PADDLE_ENFORCE_EQ(has_replaced, true, PADDLE_ENFORCE_EQ(has_replaced, true,
platform::errors::NotFound("Not find %s in the node list.", platform::errors::NotFound("Not found %s in the node list.",
cur_node->Name())); cur_node->Name()));
return new_list; return new_list;
} }
......
...@@ -42,7 +42,8 @@ void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const { ...@@ -42,7 +42,8 @@ void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
// ele_add(x, act(y)) // ele_add(x, act(y))
ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct( ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const { ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("elewise_add_act", graph); FusePassBase::Init("elewise_add_act", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -93,7 +94,8 @@ ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct( ...@@ -93,7 +94,8 @@ ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct(
// act(ele_add(x,y)) // act(ele_add(x,y))
ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd( ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const { ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("act_elewise_add", graph); FusePassBase::Init("act_elewise_add", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -145,7 +147,8 @@ ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd( ...@@ -145,7 +147,8 @@ ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"] // ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad( ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const { ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("elewise_add_act_grad", graph); FusePassBase::Init("elewise_add_act_grad", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -252,10 +255,11 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { ...@@ -252,10 +255,11 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
bool save_intermediate_out = BOOST_GET_CONST( bool save_intermediate_out = BOOST_GET_CONST(
bool, cur_node->Op()->GetAttr("save_intermediate_out")); bool, cur_node->Op()->GetAttr("save_intermediate_out"));
auto intermediate_out_args = cur_node->Op()->Output("IntermediateOut"); auto intermediate_out_args = cur_node->Op()->Output("IntermediateOut");
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
save_intermediate_out && !intermediate_out_args.empty(), (save_intermediate_out && !intermediate_out_args.empty()), true,
"The %s should save the intermediate_out in the fusing stage.", platform::errors::InvalidArgument(
cur_node->Name()); "The %s should save the intermediate out in the fusing stage.",
cur_node->Name()));
// If the intermediate_out's output is empty, it should be removed. // If the intermediate_out's output is empty, it should be removed.
auto cur_node_outputs = cur_node->outputs; auto cur_node_outputs = cur_node->outputs;
...@@ -271,10 +275,11 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { ...@@ -271,10 +275,11 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
} else if (cur_node->Name() == "fused_elemwise_activation_grad") { } else if (cur_node->Name() == "fused_elemwise_activation_grad") {
auto intermediate_out_grad_args = auto intermediate_out_grad_args =
cur_node->Op()->Output(GradVarName("IntermediateOut")); cur_node->Op()->Output(GradVarName("IntermediateOut"));
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
!intermediate_out_grad_args.empty(), intermediate_out_grad_args.empty(), false,
"The %s should save the intermediate_out in the fusing stage.", platform::errors::InvalidArgument(
cur_node->Name()); "The %s should save the intermediate out in the fusing stage.",
cur_node->Name()));
auto cur_node_outputs = cur_node->outputs; auto cur_node_outputs = cur_node->outputs;
// If the intermediate_out_g's output is empty, it should be removed. // If the intermediate_out_g's output is empty, it should be removed.
for (auto &out : cur_node_outputs) { for (auto &out : cur_node_outputs) {
...@@ -312,7 +317,11 @@ void FuseElewiseAddActPass::ReLinkNodes(Graph *graph, ...@@ -312,7 +317,11 @@ void FuseElewiseAddActPass::ReLinkNodes(Graph *graph,
nodes2delete.emplace(out); nodes2delete.emplace(out);
} }
} else { } else {
PADDLE_ENFORCE(out == intermediate_out); PADDLE_ENFORCE_EQ(
out, intermediate_out,
platform::errors::InvalidArgument(
"Output of op(%s) must be %s, but not %s.", op_1->Name(),
intermediate_out->Name(), out->Name()));
IR_OP_VAR_LINK(fused_op, out); IR_OP_VAR_LINK(fused_op, out);
} }
} }
...@@ -347,8 +356,9 @@ std::vector<Node *> FuseElewiseAddActPass::ReplaceNode( ...@@ -347,8 +356,9 @@ std::vector<Node *> FuseElewiseAddActPass::ReplaceNode(
} }
return node; return node;
}); });
PADDLE_ENFORCE(has_replaced, "Not find %s in the node list.", PADDLE_ENFORCE_EQ(has_replaced, true,
cur_node->Name()); platform::errors::NotFound("Not found %s in the node list.",
cur_node->Name()));
return new_list; return new_list;
} }
......
...@@ -25,14 +25,19 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const { ...@@ -25,14 +25,19 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const {
} }
Scope* FusePassBase::param_scope() const { Scope* FusePassBase::param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr)); PADDLE_ENFORCE_EQ(graph_->Has(kParamScopeAttr), true,
platform::errors::InvalidArgument(
"Graph must have kParamScopeAttr attribute."));
auto& scope = graph_->Get<framework::Scope>(kParamScopeAttr); auto& scope = graph_->Get<framework::Scope>(kParamScopeAttr);
return &scope; return &scope;
} }
void FusePassBase::AddStatis(int count_of_fused) const { void FusePassBase::AddStatis(int count_of_fused) const {
PADDLE_ENFORCE(graph_); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE(!repr_.empty()); graph_, platform::errors::InvalidArgument("Graph cannot be nullptr."));
PADDLE_ENFORCE_EQ(repr_.empty(), false,
platform::errors::InvalidArgument(
"Fuse pass must be initialized with a name."));
if (!graph_->Has(kFuseStatisAttr)) { if (!graph_->Has(kFuseStatisAttr)) {
graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>); graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
} }
......
...@@ -31,7 +31,8 @@ void FuseReluDepthwiseConvPass::ApplyImpl(ir::Graph *graph) const { ...@@ -31,7 +31,8 @@ void FuseReluDepthwiseConvPass::ApplyImpl(ir::Graph *graph) const {
ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv( ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
ir::Graph *graph, bool only_forward) const { ir::Graph *graph, bool only_forward) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
if (only_forward) if (only_forward)
FusePassBase::Init("relu_depthwise_conv_only_forward", graph); FusePassBase::Init("relu_depthwise_conv_only_forward", graph);
else else
...@@ -110,23 +111,45 @@ ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv( ...@@ -110,23 +111,45 @@ ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
xg_var = subgraph.at(xg)->Var(); xg_var = subgraph.at(xg)->Var();
} }
PADDLE_ENFORCE_EQ(layer_op->Input("Input").size(), 1UL); PADDLE_ENFORCE_EQ(layer_op->Input("Input").size(), 1UL,
PADDLE_ENFORCE_EQ(layer_op->Input("Input")[0], y_var->Name()); platform::errors::InvalidArgument(
"Op(%s)'s input size(%d) must be 1.",
layer_op->Type(), layer_op->Input("Input").size()));
PADDLE_ENFORCE_EQ(
layer_op->Input("Input")[0], y_var->Name(),
platform::errors::InvalidArgument(
"Op(%s)'s input name(%s) must be %s.", layer_op->Type(),
layer_op->Input("Input")[0], y_var->Name()));
layer_op->SetInput("Input", {x_var->Name()}); layer_op->SetInput("Input", {x_var->Name()});
subgraph.at(layer)->inputs.push_back(subgraph.at(x)); subgraph.at(layer)->inputs.push_back(subgraph.at(x));
subgraph.at(x)->outputs.push_back(subgraph.at(layer)); subgraph.at(x)->outputs.push_back(subgraph.at(layer));
VLOG(4) << "replace " << y_var->Name() << " -> " << x_var->Name(); VLOG(4) << "replace " << y_var->Name() << " -> " << x_var->Name();
if (!only_forward) { if (!only_forward) {
PADDLE_ENFORCE_EQ(layer_g_op->Input("Input").size(), 1UL); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(layer_g_op->Input("Input")[0], y_var->Name()); layer_g_op->Input("Input").size(), 1UL,
platform::errors::InvalidArgument(
"Op(%s)'s input size(%d) must be 1.", layer_g_op->Type(),
layer_g_op->Input("Input").size()));
PADDLE_ENFORCE_EQ(
layer_g_op->Input("Input")[0], y_var->Name(),
platform::errors::InvalidArgument(
"Op(%s)'s input name(%s) must be %s.", layer_g_op->Type(),
layer_g_op->Input("Input")[0], y_var->Name()));
layer_g_op->SetInput("Input", {x_var->Name()}); layer_g_op->SetInput("Input", {x_var->Name()});
subgraph.at(layer_g)->inputs.push_back(subgraph.at(x)); subgraph.at(layer_g)->inputs.push_back(subgraph.at(x));
subgraph.at(x)->outputs.push_back(subgraph.at(layer_g)); subgraph.at(x)->outputs.push_back(subgraph.at(layer_g));
PADDLE_ENFORCE_EQ(layer_g_op->Output(GradVarName("Input")).size(), 1UL); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(layer_g_op->Output(GradVarName("Input"))[0], layer_g_op->Output(GradVarName("Input")).size(), 1UL,
yg_var->Name()); platform::errors::InvalidArgument(
"Op(%s)'s input size(%d) must be 1.", layer_g_op->Type(),
layer_g_op->Output(GradVarName("Input")).size()));
PADDLE_ENFORCE_EQ(
layer_g_op->Output(GradVarName("Input"))[0], yg_var->Name(),
platform::errors::InvalidArgument(
"Op(%s)'s input name(%s) must be %s.", layer_g_op->Type(),
layer_g_op->Output(GradVarName("Input"))[0], yg_var->Name()));
layer_g_op->SetOutput(GradVarName("Input"), {xg_var->Name()}); layer_g_op->SetOutput(GradVarName("Input"), {xg_var->Name()});
subgraph.at(layer_g)->outputs.push_back(subgraph.at(xg)); subgraph.at(layer_g)->outputs.push_back(subgraph.at(xg));
subgraph.at(xg)->inputs.push_back(subgraph.at(layer_g)); subgraph.at(xg)->inputs.push_back(subgraph.at(layer_g));
......
...@@ -136,7 +136,9 @@ bool FindCircleSubGraph(const Graph &graph, ...@@ -136,7 +136,9 @@ bool FindCircleSubGraph(const Graph &graph,
std::vector<ir::Node *> TopologySortOperations(const Graph &graph) { std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp> std::map<ir::Node *, std::set<ir::Node *, ir::NodeComp>, ir::NodeComp>
adj_list = BuildOperationAdjList(graph); adj_list = BuildOperationAdjList(graph);
PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr)); PADDLE_ENFORCE_EQ(HasCircleInternal(adj_list, nullptr), false,
platform::errors::InvalidArgument(
"Generated graph shouldn't contain cycle."));
std::unordered_set<ir::Node *> visited; std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret; std::vector<ir::Node *> ret;
for (auto adj : adj_list) { for (auto adj : adj_list) {
...@@ -161,7 +163,11 @@ BuildOperationAdjList(const Graph &graph) { ...@@ -161,7 +163,11 @@ BuildOperationAdjList(const Graph &graph) {
} }
for (auto &var : n->inputs) { for (auto &var : n->inputs) {
for (auto &adj_n : var->inputs) { for (auto &adj_n : var->inputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation); PADDLE_ENFORCE_EQ(
adj_n->NodeType(), ir::Node::Type::kOperation,
platform::errors::InvalidArgument(
"Node(%s)'s type(%d) must be kOperation type.", adj_n->Name(),
static_cast<int>(adj_n->NodeType())));
VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n) VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n) << " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var); << " via " << var->Name() << reinterpret_cast<void *>(var);
...@@ -184,7 +190,11 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationOutAdjList( ...@@ -184,7 +190,11 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationOutAdjList(
} }
for (auto &var : n->outputs) { for (auto &var : n->outputs) {
for (auto &adj_n : var->outputs) { for (auto &adj_n : var->outputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation); PADDLE_ENFORCE_EQ(
adj_n->NodeType(), ir::Node::Type::kOperation,
platform::errors::InvalidArgument(
"Node(%s)'s type(%d) must be kOperation type.", adj_n->Name(),
static_cast<int>(adj_n->NodeType())));
VLOG(40) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n) VLOG(40) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n) << " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var); << " via " << var->Name() << reinterpret_cast<void *>(var);
...@@ -359,7 +369,10 @@ size_t GraphNum(const Graph &graph) { ...@@ -359,7 +369,10 @@ size_t GraphNum(const Graph &graph) {
} }
std::unique_ptr<std::ostream> fout( std::unique_ptr<std::ostream> fout(
new std::ofstream(FLAGS_print_sub_graph_dir)); new std::ofstream(FLAGS_print_sub_graph_dir));
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE_EQ(fout->good(), true,
platform::errors::Unavailable(
"Can not open file %s for printing the graph.",
FLAGS_print_sub_graph_dir));
*fout << out.str(); *fout << out.str();
} }
} }
......
...@@ -37,12 +37,14 @@ NodesDFSIterator::NodesDFSIterator(const NodesDFSIterator &other) ...@@ -37,12 +37,14 @@ NodesDFSIterator::NodesDFSIterator(const NodesDFSIterator &other)
: stack_(other.stack_), visited_(other.visited_) {} : stack_(other.stack_), visited_(other.visited_) {}
Node &NodesDFSIterator::operator*() { Node &NodesDFSIterator::operator*() {
PADDLE_ENFORCE(!stack_.empty()); PADDLE_ENFORCE_EQ(stack_.empty(), false, platform::errors::OutOfRange(
"The iterator exceeds range."));
return *stack_.top(); return *stack_.top();
} }
NodesDFSIterator &NodesDFSIterator::operator++() { NodesDFSIterator &NodesDFSIterator::operator++() {
PADDLE_ENFORCE(!stack_.empty(), "the iterator exceeds range"); PADDLE_ENFORCE_EQ(stack_.empty(), false, platform::errors::OutOfRange(
"The iterator exceeds range."));
visited_.insert(stack_.top()); visited_.insert(stack_.top());
auto *cur = stack_.top(); auto *cur = stack_.top();
stack_.pop(); stack_.pop();
...@@ -73,11 +75,18 @@ inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) { ...@@ -73,11 +75,18 @@ inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) {
} }
NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) { NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) {
PADDLE_ENFORCE(!source.empty(), PADDLE_ENFORCE_EQ(
"Start points of topological sorting should not be empty!"); source.empty(), false,
platform::errors::InvalidArgument(
"Start points of topological sorting should not be empty!"));
// CHECK all the inputs' in-degree is 0 // CHECK all the inputs' in-degree is 0
for (auto *node : source) { for (auto *node : source) {
PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0)); PADDLE_ENFORCE_EQ(
CheckNodeIndegreeEquals(*node, 0), true,
platform::errors::InvalidArgument(
"In start points of topological sorting, the indegree of each "
"point should be 0. Node(%s)'s indegree is not 0.",
node->Name()));
} }
std::set<Node *> to_visit{source.begin(), source.end()}; std::set<Node *> to_visit{source.begin(), source.end()};
...@@ -106,7 +115,11 @@ NodesTSIterator::NodesTSIterator(const NodesTSIterator &other) ...@@ -106,7 +115,11 @@ NodesTSIterator::NodesTSIterator(const NodesTSIterator &other)
: sorted_(other.sorted_), cursor_(other.cursor_) {} : sorted_(other.sorted_), cursor_(other.cursor_) {}
Node &NodesTSIterator::operator*() { Node &NodesTSIterator::operator*() {
PADDLE_ENFORCE_LT(cursor_, sorted_.size()); PADDLE_ENFORCE_LT(
cursor_, sorted_.size(),
platform::errors::OutOfRange(
"The iterator exceeds range. Container size is %d, but index is %d.",
sorted_.size(), cursor_));
return *sorted_[cursor_]; return *sorted_[cursor_];
} }
...@@ -128,7 +141,11 @@ bool NodesTSIterator::operator==(const NodesTSIterator &other) { ...@@ -128,7 +141,11 @@ bool NodesTSIterator::operator==(const NodesTSIterator &other) {
} }
Node *NodesTSIterator::operator->() { Node *NodesTSIterator::operator->() {
PADDLE_ENFORCE_LT(cursor_, sorted_.size()); PADDLE_ENFORCE_LT(
cursor_, sorted_.size(),
platform::errors::OutOfRange(
"The iterator exceeds range. Container size is %d, but index is %d.",
sorted_.size(), cursor_));
return sorted_[cursor_]; return sorted_[cursor_];
} }
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#pragma once #pragma once
#include <stack> #include <stack>
#include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
...@@ -66,7 +68,7 @@ struct NodesDFSIterator ...@@ -66,7 +68,7 @@ struct NodesDFSIterator
struct NodesTSIterator struct NodesTSIterator
: public std::iterator<std::forward_iterator_tag, Node *> { : public std::iterator<std::forward_iterator_tag, Node *> {
NodesTSIterator() = default; NodesTSIterator() = default;
NodesTSIterator(const std::vector<Node *> &source); explicit NodesTSIterator(const std::vector<Node *> &source);
NodesTSIterator(NodesTSIterator &&other) NodesTSIterator(NodesTSIterator &&other)
: sorted_(std::move(other.sorted_)), cursor_(other.cursor_) { : sorted_(std::move(other.sorted_)), cursor_(other.cursor_) {
other.cursor_ = 0; other.cursor_ = 0;
...@@ -104,7 +106,10 @@ struct GraphTraits { ...@@ -104,7 +106,10 @@ struct GraphTraits {
static iterator_range<NodesTSIterator> TS(const Graph &g) { static iterator_range<NodesTSIterator> TS(const Graph &g) {
auto start_points = ExtractStartPoints(g); auto start_points = ExtractStartPoints(g);
PADDLE_ENFORCE(!start_points.empty()); PADDLE_ENFORCE_EQ(
start_points.empty(), false,
platform::errors::InvalidArgument(
"Start points of topological sorting should not be empty!"));
NodesTSIterator x(start_points); NodesTSIterator x(start_points);
return iterator_range<NodesTSIterator>(NodesTSIterator(start_points), return iterator_range<NodesTSIterator>(NodesTSIterator(start_points),
NodesTSIterator()); NodesTSIterator());
......
...@@ -42,7 +42,10 @@ void GraphVizPass::ApplyImpl(ir::Graph* graph) const { ...@@ -42,7 +42,10 @@ void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
const std::string& graph_viz_path = Get<std::string>(kGraphvizPath); const std::string& graph_viz_path = Get<std::string>(kGraphvizPath);
VLOG(3) << "draw IR graph viz to " << graph_viz_path; VLOG(3) << "draw IR graph viz to " << graph_viz_path;
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path)); std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE_EQ(
fout->good(), true,
platform::errors::Unavailable(
"Can not open file %s for printing the graph.", graph_viz_path));
std::ostream& sout = *fout; std::ostream& sout = *fout;
std::unordered_map<const ir::Node*, std::string> node2dot; std::unordered_map<const ir::Node*, std::string> node2dot;
......
...@@ -64,7 +64,11 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const { ...@@ -64,7 +64,11 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const {
for (auto& parameter : *pre_op_desc->Proto()->mutable_outputs()) { for (auto& parameter : *pre_op_desc->Proto()->mutable_outputs()) {
auto* arguments = parameter.mutable_arguments(); auto* arguments = parameter.mutable_arguments();
auto it = std::find(arguments->begin(), arguments->end(), scale_in_name); auto it = std::find(arguments->begin(), arguments->end(), scale_in_name);
PADDLE_ENFORCE(it != arguments->end()); PADDLE_ENFORCE_NE(
it, arguments->end(),
platform::errors::NotFound(
"Can not find input variable(%s) from scale op(%s).",
scale_in_name, pre_op_desc->Type()));
*it = scale_out_name; *it = scale_out_name;
} }
......
...@@ -33,7 +33,8 @@ const char kSumGradOpName[] = "sum"; ...@@ -33,7 +33,8 @@ const char kSumGradOpName[] = "sum";
const char kOptimizerType[] = "sgd"; const char kOptimizerType[] = "sgd";
void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const { void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
// We could collect all weights' name from SGD, where // We could collect all weights' name from SGD, where
// W1 <- SGD(W0, Grad0) // W1 <- SGD(W0, Grad0)
...@@ -41,7 +42,10 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -41,7 +42,10 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
if (IsOpNamed(node, kOptimizerType)) { if (IsOpNamed(node, kOptimizerType)) {
auto& param_out_vars = node->Op()->Output("ParamOut"); auto& param_out_vars = node->Op()->Output("ParamOut");
PADDLE_ENFORCE(param_out_vars.size() == 1u); PADDLE_ENFORCE_EQ(
param_out_vars.size(), 1u,
platform::errors::InvalidArgument(
"In op(%s), find output(ParamOut) failed.", node->Name()));
weight_var_set.insert(param_out_vars[0]); weight_var_set.insert(param_out_vars[0]);
} }
} }
...@@ -95,12 +99,19 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -95,12 +99,19 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Found forward_op " << forward_op->Name(); VLOG(3) << "Found forward_op " << forward_op->Name();
PADDLE_ENFORCE(forward_op); PADDLE_ENFORCE_NOT_NULL(
forward_op, platform::errors::NotFound(
"Can not find forward op for backword op(%s).",
backward_op->Name()));
Node* new_optimizer_node = CreateNewSGDNode( Node* new_optimizer_node = CreateNewSGDNode(
graph, forward_op, backward_op, node, opt_node); graph, forward_op, backward_op, node, opt_node);
PADDLE_ENFORCE(new_optimizer_node); PADDLE_ENFORCE_NOT_NULL(
new_optimizer_node,
platform::errors::InvalidArgument(
"Create new SGD node failed, backward op is %s.",
backward_op->Name()));
} }
} }
} }
...@@ -144,11 +155,21 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -144,11 +155,21 @@ void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
ir::Node* LockFreeOptimizePass::CreateNewSGDNode( ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
ir::Graph* graph, ir::Node* forward_node, ir::Node* backward_node, ir::Graph* graph, ir::Node* forward_node, ir::Node* backward_node,
ir::Node* grad_sum_node, ir::Node* optimize_node) const { ir::Node* grad_sum_node, ir::Node* optimize_node) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(graph,
PADDLE_ENFORCE(forward_node); platform::errors::InvalidArgument(
PADDLE_ENFORCE(backward_node); "Input argument graph cannot be nullptr."));
PADDLE_ENFORCE(grad_sum_node); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE(optimize_node); forward_node, platform::errors::InvalidArgument(
"Input argument forward_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
backward_node, platform::errors::InvalidArgument(
"Input argument backward_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
grad_sum_node, platform::errors::InvalidArgument(
"Input argument grad_sum_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
optimize_node, platform::errors::InvalidArgument(
"Input argument optimize_node cannot be nullptr."));
// find the grad var node between the grad sum node and backward_node // find the grad var node between the grad sum node and backward_node
std::vector<ir::Node*> grad_vars = std::vector<ir::Node*> grad_vars =
...@@ -159,7 +180,8 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode( ...@@ -159,7 +180,8 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
grad_node = node; grad_node = node;
} }
} }
PADDLE_ENFORCE(grad_node); PADDLE_ENFORCE_NOT_NULL(grad_node, platform::errors::NotFound(
"Can not find control dep variable."));
// create a new SGD node // create a new SGD node
OpDesc* old_desc = optimize_node->Op(); OpDesc* old_desc = optimize_node->Op();
...@@ -212,8 +234,14 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode( ...@@ -212,8 +234,14 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
} }
// SGD must have only one param and LR in // SGD must have only one param and LR in
PADDLE_ENFORCE(old_desc->Input("LearningRate").size() == 1u); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE(old_desc->Input("Param").size() == 1u); old_desc->Input("LearningRate").size(), 1u,
platform::errors::InvalidArgument(
"In op(%s), find input(LearningRate) failed.", old_desc->Type()));
PADDLE_ENFORCE_EQ(
old_desc->Input("Param").size(), 1u,
platform::errors::InvalidArgument("In op(%s), find input(Param) failed.",
old_desc->Type()));
// LR and weight nodes should be copied // LR and weight nodes should be copied
for (Node* upstream_node : optimize_node->inputs) { for (Node* upstream_node : optimize_node->inputs) {
...@@ -245,9 +273,17 @@ std::vector<ir::Node*> LockFreeOptimizePass::FindConnectedNode( ...@@ -245,9 +273,17 @@ std::vector<ir::Node*> LockFreeOptimizePass::FindConnectedNode(
void LockFreeOptimizePass::ReplaceUpstreamNode( void LockFreeOptimizePass::ReplaceUpstreamNode(
ir::Node* upstream_node, ir::Node* old_optimizer_node, ir::Node* upstream_node, ir::Node* old_optimizer_node,
ir::Node* new_optimizer_node) const { ir::Node* new_optimizer_node) const {
PADDLE_ENFORCE(upstream_node); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE(old_optimizer_node); upstream_node, platform::errors::InvalidArgument(
PADDLE_ENFORCE(new_optimizer_node); "Input argument upstream_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
old_optimizer_node,
platform::errors::InvalidArgument(
"Input argument old_optimizer_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
new_optimizer_node,
platform::errors::InvalidArgument(
"Input argument new_optimizer_node cannot be nullptr."));
// Remove the old_optimizer_node from upstream_node's outputs vector // Remove the old_optimizer_node from upstream_node's outputs vector
auto& output_node_vec = upstream_node->outputs; auto& output_node_vec = upstream_node->outputs;
...@@ -268,8 +304,14 @@ void LockFreeOptimizePass::ReplaceUpstreamNode( ...@@ -268,8 +304,14 @@ void LockFreeOptimizePass::ReplaceUpstreamNode(
void LockFreeOptimizePass::ReplaceAllDownstreamNode( void LockFreeOptimizePass::ReplaceAllDownstreamNode(
ir::Node* old_optimizer_node, ir::Node* new_optimizer_node) const { ir::Node* old_optimizer_node, ir::Node* new_optimizer_node) const {
PADDLE_ENFORCE(old_optimizer_node); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE(new_optimizer_node); old_optimizer_node,
platform::errors::InvalidArgument(
"Input argument old_optimizer_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
new_optimizer_node,
platform::errors::InvalidArgument(
"Input argument new_optimizer_node cannot be nullptr."));
for (ir::Node* downstream_node : old_optimizer_node->outputs) { for (ir::Node* downstream_node : old_optimizer_node->outputs) {
// Remove the old_optimizer_node from downstream_node's inputs vector // Remove the old_optimizer_node from downstream_node's inputs vector
...@@ -292,8 +334,12 @@ void LockFreeOptimizePass::ReplaceAllDownstreamNode( ...@@ -292,8 +334,12 @@ void LockFreeOptimizePass::ReplaceAllDownstreamNode(
ir::Node* LockFreeOptimizePass::FindForwardOpViaBackwardOp( ir::Node* LockFreeOptimizePass::FindForwardOpViaBackwardOp(
ir::Graph* graph, ir::Node* backward_node) const { ir::Graph* graph, ir::Node* backward_node) const {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE_NOT_NULL(graph,
PADDLE_ENFORCE(backward_node); platform::errors::InvalidArgument(
"Input argument graph cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
backward_node, platform::errors::InvalidArgument(
"Input argument backward_node cannot be nullptr."));
// strip the suffix _grad of backward_node's name // strip the suffix _grad of backward_node's name
std::string forward_op_name = backward_node->Name(); std::string forward_op_name = backward_node->Name();
......
...@@ -87,34 +87,46 @@ class LockFreeOptimizePass : public Pass { ...@@ -87,34 +87,46 @@ class LockFreeOptimizePass : public Pass {
ir::Node* downstream_node) const; ir::Node* downstream_node) const;
inline bool IsOpNamed(ir::Node* node, const std::string& name) const { inline bool IsOpNamed(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node); PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return node->NodeType() == Node::Type::kOperation && node->Name() == name; return node->NodeType() == Node::Type::kOperation && node->Name() == name;
} }
inline bool IsVarNamed(ir::Node* node, const std::string& name) const { inline bool IsVarNamed(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node); PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return node->NodeType() == Node::Type::kVariable && node->Name() == name; return node->NodeType() == Node::Type::kVariable && node->Name() == name;
} }
inline bool IsVarNameEndsWith(ir::Node* node, const std::string& name) const { inline bool IsVarNameEndsWith(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node); PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return node->NodeType() == Node::Type::kVariable && return node->NodeType() == Node::Type::kVariable &&
boost::algorithm::ends_with(node->Name(), name); boost::algorithm::ends_with(node->Name(), name);
} }
inline bool IsVarNameContains(ir::Node* node, const std::string& name) const { inline bool IsVarNameContains(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node); PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return node->NodeType() == Node::Type::kVariable && return node->NodeType() == Node::Type::kVariable &&
node->Name().find(name) != std::string::npos; node->Name().find(name) != std::string::npos;
} }
inline bool IsControlDepFrom(ir::Node* ctrl_dep_node, ir::Node* node) const { inline bool IsControlDepFrom(ir::Node* ctrl_dep_node, ir::Node* node) const {
PADDLE_ENFORCE(ctrl_dep_node); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE(node); ctrl_dep_node, platform::errors::InvalidArgument(
"Input argument ctrl_dep_node cannot be nullptr."));
PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::InvalidArgument(
"Input argument node cannot be nullptr."));
return IsControlDepVar(*ctrl_dep_node) && return IsControlDepVar(*ctrl_dep_node) &&
ctrl_dep_node->inputs.size() >= 1u && ctrl_dep_node->inputs.size() >= 1u &&
......
...@@ -85,7 +85,9 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const { ...@@ -85,7 +85,9 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
// 1. record op nodes of different roles // 1. record op nodes of different roles
for (auto node : nodes) { for (auto node : nodes) {
if (!node->IsOp()) continue; if (!node->IsOp()) continue;
PADDLE_ENFORCE(node->Op(), "must find opdesc"); PADDLE_ENFORCE_NOT_NULL(
node->Op(), platform::errors::InvalidArgument(
"Node(%s) must hold op description.", node->Name()));
int op_role = BOOST_GET_CONST( int op_role = BOOST_GET_CONST(
int, node->Op()->GetAttr( int, node->Op()->GetAttr(
framework::OpProtoAndCheckerMaker::OpRoleAttrName())); framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
...@@ -108,7 +110,9 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const { ...@@ -108,7 +110,9 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
} else if (op_role & static_cast<int>(framework::OpRole::kLRSched)) { } else if (op_role & static_cast<int>(framework::OpRole::kLRSched)) {
lr_ops.push_back(node); lr_ops.push_back(node);
} else { // NOLINT } else { // NOLINT
PADDLE_THROW("Invalid op_role: %d", static_cast<int>(op_role)); PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid op role(%d), in node(%s).", static_cast<int>(op_role),
node->Name()));
} }
} }
......
...@@ -66,12 +66,18 @@ class Node { ...@@ -66,12 +66,18 @@ class Node {
std::string Name() const { return name_; } std::string Name() const { return name_; }
VarDesc* Var() const { VarDesc* Var() const {
PADDLE_ENFORCE_EQ(IsVar(), true); PADDLE_ENFORCE_EQ(IsVar(), true,
platform::errors::InvalidArgument(
"Node(%s) must be kVariable type, but not %d.", name_,
static_cast<int>(type_)));
return var_desc_.get(); return var_desc_.get();
} }
OpDesc* Op() const { OpDesc* Op() const {
PADDLE_ENFORCE_EQ(IsOp(), true); PADDLE_ENFORCE_EQ(IsOp(), true,
platform::errors::InvalidArgument(
"Node(%s) must be kOperation type, but not %d.",
name_, static_cast<int>(type_)));
return op_desc_.get(); return op_desc_.get();
} }
...@@ -92,8 +98,9 @@ class Node { ...@@ -92,8 +98,9 @@ class Node {
try { try {
return *boost::any_cast<T*>(wrapper_); return *boost::any_cast<T*>(wrapper_);
} catch (boost::bad_any_cast&) { } catch (boost::bad_any_cast&) {
PADDLE_THROW("Invalid wrapper type error, expected %s, actual %s", PADDLE_THROW(platform::errors::InvalidArgument(
typeid(T).name(), wrapper_type_.name()); "Invalid wrapper type error, expected %s, actual %s.",
typeid(T).name(), wrapper_type_.name()));
} }
} }
...@@ -114,8 +121,9 @@ class Node { ...@@ -114,8 +121,9 @@ class Node {
} }
void RenameVar(const std::string& new_name) { void RenameVar(const std::string& new_name) {
PADDLE_ENFORCE(type_ == Type::kVariable && var_desc_, PADDLE_ENFORCE_EQ(
"Must be type of variable"); type_ == Type::kVariable && var_desc_, true,
platform::errors::InvalidArgument("Node must be type of variable."));
name_ = new_name; name_ = new_name;
var_desc_->SetName(new_name); var_desc_->SetName(new_name);
} }
......
...@@ -19,6 +19,9 @@ limitations under the License. */ ...@@ -19,6 +19,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -26,7 +29,8 @@ namespace ir { ...@@ -26,7 +29,8 @@ namespace ir {
Graph* Pass::Apply(Graph* graph) const { Graph* Pass::Apply(Graph* graph) const {
CheckPrevPass(); CheckPrevPass();
PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty."); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
for (const std::string& attr : required_pass_attrs_) { for (const std::string& attr : required_pass_attrs_) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
attrs_.find(attr), attrs_.end(), attrs_.find(attr), attrs_.end(),
...@@ -40,11 +44,14 @@ Graph* Pass::Apply(Graph* graph) const { ...@@ -40,11 +44,14 @@ Graph* Pass::Apply(Graph* graph) const {
} }
ApplyImpl(graph); ApplyImpl(graph);
// TODO(panyx0718): Add more verifications. // TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE(!HasCircle(*graph), PADDLE_ENFORCE_EQ(
"Illegal Pass %s. Generated graph shouldn't have cycle.", HasCircle(*graph), false,
Type()); platform::errors::InvalidArgument(
PADDLE_ENFORCE(VarDescIsConsistency(*graph), "Illegal pass %s. Generated graph shouldn't contain cycle.", Type()));
"The VarDescs of persistable variable are not consistency."); PADDLE_ENFORCE_EQ(
VarDescIsConsistency(*graph), true,
platform::errors::InvalidArgument(
"The VarDescs of persistable variable are not consistency."));
applied_ = true; applied_ = true;
if (!graph->Has(kPassRecorder)) { if (!graph->Has(kPassRecorder)) {
graph->Set<PassRecorder>(kPassRecorder, new PassRecorder); graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
...@@ -53,10 +60,7 @@ Graph* Pass::Apply(Graph* graph) const { ...@@ -53,10 +60,7 @@ Graph* Pass::Apply(Graph* graph) const {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache, // Clear mkl-dnn cache,
// Passes can change params, tensors, so caching need to be discarded // Passes can change params, tensors, so caching need to be discarded
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); ClearMKLDNNCache(paddle::platform::CPUPlace());
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(paddle::platform::CPUPlace());
dev_ctx->ResetBlobMap();
#endif #endif
return graph; return graph;
} }
......
...@@ -55,8 +55,9 @@ class Pass { ...@@ -55,8 +55,9 @@ class Pass {
// Get a reference to the attributed previously set. // Get a reference to the attributed previously set.
template <typename AttrType> template <typename AttrType>
AttrType &Get(const std::string &attr_name) const { AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(), PADDLE_ENFORCE_NE(attrs_.find(attr_name), attrs_.end(),
"%s attr not registered for pass.", attr_name); platform::errors::InvalidArgument(
"Attribute %s not registered for pass.", attr_name));
try { try {
return *boost::any_cast<AttrType *>(attrs_.at(attr_name)); return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
} catch (boost::bad_any_cast &) { } catch (boost::bad_any_cast &) {
...@@ -76,7 +77,7 @@ class Pass { ...@@ -76,7 +77,7 @@ class Pass {
}; };
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid type for attritube %s, expected: %s, actual: %s", attr_name, "Invalid type for attritube %s, expected: %s, actual: %s.", attr_name,
TypeToString(typeid(AttrType *)), TypeToString(typeid(AttrType *)),
TypeToString(attrs_.at(attr_name).type()))); TypeToString(attrs_.at(attr_name).type())));
} }
...@@ -101,9 +102,10 @@ class Pass { ...@@ -101,9 +102,10 @@ class Pass {
template <typename AttrType> template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) { void Set(const std::string &attr_name, AttrType *attr) {
if (default_pass_attrs_.count(attr_name) == 0) { if (default_pass_attrs_.count(attr_name) == 0) {
PADDLE_ENFORCE_EQ(attrs_.count(attr_name), 0, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( attrs_.count(attr_name), 0,
"Attribute %s already set in the pass", attr_name)); platform::errors::AlreadyExists(
"Attribute %s already set in the pass.", attr_name));
} else { } else {
VLOG(3) << "Setting the attribute " << attr_name << " for the pass " VLOG(3) << "Setting the attribute " << attr_name << " for the pass "
<< type_; << type_;
...@@ -119,15 +121,16 @@ class Pass { ...@@ -119,15 +121,16 @@ class Pass {
// should delete the attribute. // should delete the attribute.
template <typename AttrType> template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr) { void SetNotOwned(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass", PADDLE_ENFORCE_EQ(attrs_.count(attr_name), 0,
attr_name); platform::errors::AlreadyExists(
"Attribute %s already set in the pass.", attr_name));
attrs_[attr_name] = attr; attrs_[attr_name] = attr;
} }
protected: protected:
virtual void ApplyImpl(Graph *graph) const { virtual void ApplyImpl(Graph *graph) const {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"The virtual Pass called is not implemented.")); "The virtual pass called is not implemented."));
} }
// Some Pass must be placed before this Pass, and some // Some Pass must be placed before this Pass, and some
...@@ -198,8 +201,9 @@ class PassRegistry { ...@@ -198,8 +201,9 @@ class PassRegistry {
} }
std::unique_ptr<Pass> Get(const std::string &pass_type) const { std::unique_ptr<Pass> Get(const std::string &pass_type) const {
PADDLE_ENFORCE(Has(pass_type), "Pass %s has not been registered", PADDLE_ENFORCE_EQ(Has(pass_type), true,
pass_type); platform::errors::InvalidArgument(
"Pass %s has not been registered.", pass_type));
return map_.at(pass_type)(); return map_.at(pass_type)();
} }
...@@ -213,8 +217,10 @@ class PassRegistry { ...@@ -213,8 +217,10 @@ class PassRegistry {
template <typename PassType> template <typename PassType>
struct PassRegistrar : public Registrar { struct PassRegistrar : public Registrar {
explicit PassRegistrar(const char *pass_type) { explicit PassRegistrar(const char *pass_type) {
PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type), PADDLE_ENFORCE_EQ(
"'%s' is registered more than once.", pass_type); PassRegistry::Instance().Has(pass_type), false,
platform::errors::AlreadyExists(
"Pass '%s' is registered more than once.", pass_type));
PassRegistry::Instance().Insert( PassRegistry::Instance().Insert(
pass_type, [this, pass_type]() -> std::unique_ptr<Pass> { pass_type, [this, pass_type]() -> std::unique_ptr<Pass> {
std::unique_ptr<Pass> pass(new PassType()); std::unique_ptr<Pass> pass(new PassType());
......
...@@ -28,13 +28,19 @@ std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) { ...@@ -28,13 +28,19 @@ std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) {
} }
void PassBuilder::RemovePass(size_t idx) { void PassBuilder::RemovePass(size_t idx) {
PADDLE_ENFORCE(passes_.size() > idx); PADDLE_ENFORCE_GT(
passes_.size(), idx,
platform::errors::InvalidArgument(
"Passes size is %d, %d is not a valid index.", passes_.size(), idx));
passes_.erase(passes_.begin() + idx); passes_.erase(passes_.begin() + idx);
} }
std::shared_ptr<Pass> PassBuilder::InsertPass(size_t idx, std::shared_ptr<Pass> PassBuilder::InsertPass(size_t idx,
const std::string& pass_type) { const std::string& pass_type) {
PADDLE_ENFORCE(passes_.size() >= idx); PADDLE_ENFORCE_GE(
passes_.size(), idx,
platform::errors::InvalidArgument(
"Passes size is %d, %d is not a valid index.", passes_.size(), idx));
std::shared_ptr<Pass> pass( std::shared_ptr<Pass> pass(
ir::PassRegistry::Instance().Get(pass_type).release()); ir::PassRegistry::Instance().Get(pass_type).release());
passes_.insert(passes_.begin() + idx, std::move(pass)); passes_.insert(passes_.begin() + idx, std::move(pass));
......
...@@ -119,7 +119,7 @@ TEST(PassTest, TestPassAttrCheck) { ...@@ -119,7 +119,7 @@ TEST(PassTest, TestPassAttrCheck) {
} catch (paddle::platform::EnforceNotMet& e) { } catch (paddle::platform::EnforceNotMet& e) {
exception = std::string(e.what()); exception = std::string(e.what());
} }
ASSERT_TRUE(exception.find("shouldn't have cycle") != exception.npos); ASSERT_TRUE(exception.find("shouldn't contain cycle") != exception.npos);
pass = PassRegistry::Instance().Get("test_pass"); pass = PassRegistry::Instance().Get("test_pass");
pass->Set<int>("test_pass_attr", new int); pass->Set<int>("test_pass_attr", new int);
......
...@@ -43,9 +43,11 @@ void DeleteQuant(ir::Graph* graph, Scope* scope, ...@@ -43,9 +43,11 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
// ops linked from it // ops linked from it
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
PADDLE_ENFORCE_EQ(subgraph.count(input_act_node), true, PADDLE_ENFORCE_EQ(
platform::errors::NotFound( subgraph.count(input_act_node), true,
"Input act node not found in Delete Quant fusion.")); platform::errors::NotFound(
"Input act node(%s) not found in QuantDequantFuse pass.",
input_act_node->name()));
Node* input_act = subgraph.at(input_act_node); Node* input_act = subgraph.at(input_act_node);
Node* input_scale = subgraph.at(pattern.GetPDNode("input_scale_node")); Node* input_scale = subgraph.at(pattern.GetPDNode("input_scale_node"));
Node* quant = subgraph.at(pattern.GetPDNode("quant_node")); Node* quant = subgraph.at(pattern.GetPDNode("quant_node"));
...@@ -58,7 +60,7 @@ void DeleteQuant(ir::Graph* graph, Scope* scope, ...@@ -58,7 +60,7 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
std::string input_scale_var_name = quant->Op()->Input("InScale").front(); std::string input_scale_var_name = quant->Op()->Input("InScale").front();
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument( scope, platform::errors::InvalidArgument(
"scope in DeleteQuantOpFuse pass should not be null.")); "Scope in QuantDequantFuse pass should not be null."));
const LoDTensor& input_scale_tensor = const LoDTensor& input_scale_tensor =
scope->FindVar(input_scale_var_name)->Get<LoDTensor>(); scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -84,8 +86,8 @@ void DeleteQuant(ir::Graph* graph, Scope* scope, ...@@ -84,8 +86,8 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
} else if (quantized_op_type == "mul") { } else if (quantized_op_type == "mul") {
op_desc->SetAttr("X_scale", scale_value); op_desc->SetAttr("X_scale", scale_value);
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported quantized op type %s", quantized_op_type)); "Unsupported quantized op type %s.", quantized_op_type));
} }
op_desc->SetAttr("bit_length", bit_length); op_desc->SetAttr("bit_length", bit_length);
op_desc->RenameInput(output_act_name, input_act_name); op_desc->RenameInput(output_act_name, input_act_name);
...@@ -119,9 +121,9 @@ void FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -119,9 +121,9 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
weight_name = "W"; weight_name = "W";
input_name = "Input"; input_name = "Input";
} else { } else {
PADDLE_ENFORCE( PADDLE_THROW(platform::errors::Unimplemented(
"QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for " "QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for "
"now."); "now."));
} }
const std::string pattern_name = "dequant_fuse"; const std::string pattern_name = "dequant_fuse";
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -141,8 +143,9 @@ void FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -141,8 +143,9 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
Graph* g) { Graph* g) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
subgraph.count(quantized_op_input), true, subgraph.count(quantized_op_input), true,
platform::errors::NotFound( platform::errors::NotFound("Quantized op input node(%s) did not find "
"Quantized op input node not found in Delete Quant fusion.")); "in QuantDequantFuse pass.",
quantized_op_input->name()));
Node* quantized_op_input_node = subgraph.at(quantized_op_input); Node* quantized_op_input_node = subgraph.at(quantized_op_input);
Node* quantized_op_weight_node = Node* quantized_op_weight_node =
subgraph.at(pattern.GetPDNode("quantized_op_weight")); subgraph.at(pattern.GetPDNode("quantized_op_weight"));
...@@ -165,7 +168,7 @@ void FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -165,7 +168,7 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scales_name.size(), 2, scales_name.size(), 2,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Scales size in channel-wise dequantize op should be 2, got %d", "Scales size in channel-wise dequantize op should be 2, got %d.",
scales_name.size())); scales_name.size()));
const LoDTensor& channel_scale_tensor = const LoDTensor& channel_scale_tensor =
scope->FindVar(scales_name[0])->Get<LoDTensor>(); scope->FindVar(scales_name[0])->Get<LoDTensor>();
...@@ -193,9 +196,10 @@ void FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -193,9 +196,10 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
bool valid_scale_size = bool valid_scale_size =
(weight_scale.size() == 1 || (weight_scale.size() == 1 ||
weight_scale.size() == static_cast<size_t>(w_dims[0])); weight_scale.size() == static_cast<size_t>(w_dims[0]));
PADDLE_ENFORCE_EQ(valid_scale_size, true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( valid_scale_size, true,
"TRT int8 quant: invalid scale size")); platform::errors::InvalidArgument(
"TRT int8 quant: invalid scale size(%d).", weight_scale.size()));
float* quantized_weight_data = float* quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace()); weight_tensor->mutable_data<float>(platform::CPUPlace());
for (int j = 0; j < weight_tensor->numel(); j++) { for (int j = 0; j < weight_tensor->numel(); j++) {
......
...@@ -278,11 +278,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -278,11 +278,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto retrieve_node = [](const std::string& name, auto retrieve_node = [](const std::string& name,
const GraphPatternDetector::subgraph_t& subgraph, const GraphPatternDetector::subgraph_t& subgraph,
const PDPattern& pat) -> Node* { const PDPattern& pat) -> Node* {
PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)), PADDLE_ENFORCE_GT(subgraph.count(pat.RetrieveNode(name)), 0,
"pattern has no Node called %s", name.c_str()); platform::errors::NotFound(
"Pattern has no node called %s.", name.c_str()));
Node* p = subgraph.at(pat.RetrieveNode(name)); Node* p = subgraph.at(pat.RetrieveNode(name));
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(p, platform::errors::NotFound(
p, platform::errors::NotFound("subgraph has no node %s", name.c_str())); "Subgraph has no node %s.", name.c_str()));
return p; return p;
}; };
...@@ -365,7 +366,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -365,7 +366,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
} }
void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const { void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph); PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
int fusion_count = 0; int fusion_count = 0;
......
...@@ -55,9 +55,15 @@ void TestMain(int num_fc) { ...@@ -55,9 +55,15 @@ void TestMain(int num_fc) {
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
// Delete (num_fc_nodes_before - 1) fc ops // Delete (num_fc_nodes_before - 1) fc ops
PADDLE_ENFORCE_EQ(num_nodes_before - (num_fc_nodes_before - 1) + 1, PADDLE_ENFORCE_EQ(
num_nodes_after); num_nodes_before - (num_fc_nodes_before - 1) + 1, num_nodes_after,
PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1); platform::errors::InvalidArgument(
"num_nodes_before = %d, num_fc_nodes_before = %d, num_nodes_after = "
"%d.",
num_nodes_before, num_fc_nodes_before, num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1,
platform::errors::InvalidArgument(
"num_fused_nodes_after = %d.", num_fused_nodes_after));
} }
TEST(RepeatedFCReluFusePass, basic_3) { TestMain(3); } TEST(RepeatedFCReluFusePass, basic_3) { TestMain(3); }
......
...@@ -185,11 +185,13 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -185,11 +185,13 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const {
auto* concat_out = BuildSeqExpandConcatPattern(pattern); auto* concat_out = BuildSeqExpandConcatPattern(pattern);
BuildFCPattern(pattern, concat_out); BuildFCPattern(pattern, concat_out);
#define GET_NODE(id, pattern) \ #define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \ PADDLE_ENFORCE_GT( \
"pattern has no Node called %s", #id); \ subgraph.count(pattern.RetrieveNode(#id)), 0, \
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \ platform::errors::NotFound("Pattern has no node called %s.", #id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL( \
id, platform::errors::NotFound("Subgraph has no node %s.", #id));
int fuse_count{0}; int fuse_count{0};
......
...@@ -139,11 +139,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -139,11 +139,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto retrieve_node = [](const std::string& name, auto retrieve_node = [](const std::string& name,
const GraphPatternDetector::subgraph_t& subgraph, const GraphPatternDetector::subgraph_t& subgraph,
const PDPattern& pat) -> Node* { const PDPattern& pat) -> Node* {
PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)), PADDLE_ENFORCE_GT(subgraph.count(pat.RetrieveNode(name)), 0,
"pattern has no Node called %s", name.c_str()); platform::errors::NotFound(
"Pattern has no node called %s.", name.c_str()));
Node* p = subgraph.at(pat.RetrieveNode(name)); Node* p = subgraph.at(pat.RetrieveNode(name));
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(p, platform::errors::NotFound(
p, platform::errors::NotFound("subgraph has no node %s", name.c_str())); "Subgraph has no node %s.", name.c_str()));
return p; return p;
}; };
......
...@@ -47,7 +47,9 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { ...@@ -47,7 +47,9 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
Graph* g) { Graph* g) {
GET_NODES; GET_NODES;
PADDLE_ENFORCE(subgraph.count(x)); PADDLE_ENFORCE_GT(
subgraph.count(x), 0,
platform::errors::NotFound("Detector did not find input X."));
auto* input_node = subgraph.at(x); auto* input_node = subgraph.at(x);
auto reshape1_desc = reshape1_op->Op(); auto reshape1_desc = reshape1_op->Op();
auto reshape2_desc = reshape2_op->Op(); auto reshape2_desc = reshape2_op->Op();
......
...@@ -59,12 +59,25 @@ TEST(SimplifyWithBasicOpsPass, dropout) { ...@@ -59,12 +59,25 @@ TEST(SimplifyWithBasicOpsPass, dropout) {
int num_scale_nodes_after = GetNumOpNodes(graph, "scale"); int num_scale_nodes_after = GetNumOpNodes(graph, "scale");
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_dropout_nodes_after, 0); PADDLE_ENFORCE_EQ(
num_dropout_nodes_after, 0,
platform::errors::InvalidArgument("num_dropout_nodes_after = %d.",
num_dropout_nodes_after));
if (dropout_implementation == "downgrade_in_infer") { if (dropout_implementation == "downgrade_in_infer") {
PADDLE_ENFORCE_EQ(num_dropout_nodes_before, PADDLE_ENFORCE_EQ(
num_scale_nodes_after - num_scale_nodes_before); num_dropout_nodes_before,
num_scale_nodes_after - num_scale_nodes_before,
platform::errors::InvalidArgument(
"num_dropout_nodes_before = %d, num_scale_nodes_after = %d, "
"num_scale_nodes_before = %d.",
num_dropout_nodes_before, num_scale_nodes_after,
num_scale_nodes_before));
} else { } else {
PADDLE_ENFORCE_EQ(num_scale_nodes_after - num_scale_nodes_before, 0); PADDLE_ENFORCE_EQ(
num_scale_nodes_after - num_scale_nodes_before, 0,
platform::errors::InvalidArgument(
"num_scale_nodes_after = %d, num_scale_nodes_before = %d.",
num_scale_nodes_after, num_scale_nodes_before));
} }
} }
} }
......
...@@ -300,10 +300,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { ...@@ -300,10 +300,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
auto retrieve_node = [](const std::string& name, auto retrieve_node = [](const std::string& name,
const GraphPatternDetector::subgraph_t& subgraph, const GraphPatternDetector::subgraph_t& subgraph,
const PDPattern& pat) -> Node* { const PDPattern& pat) -> Node* {
PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)), PADDLE_ENFORCE_GT(subgraph.count(pat.RetrieveNode(name)), 0,
"pattern has no Node called %s", name.c_str()); platform::errors::NotFound(
"Pattern has no node called %s.", name.c_str()));
Node* p = subgraph.at(pat.RetrieveNode(name)); Node* p = subgraph.at(pat.RetrieveNode(name));
PADDLE_ENFORCE_NOT_NULL(p, "subgraph has no node %s", name.c_str()); PADDLE_ENFORCE_NOT_NULL(p, platform::errors::NotFound(
"Subgraph has no node %s.", name.c_str()));
return p; return p;
}; };
......
...@@ -51,15 +51,25 @@ void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) { ...@@ -51,15 +51,25 @@ void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) {
std::vector<Node *> nodes; std::vector<Node *> nodes;
for (int i = 0; i < times; i++) { for (int i = 0; i < times; i++) {
PADDLE_ENFORCE( PADDLE_ENFORCE_NOT_NULL(
subgraph.at(pattern.GetPDNode("transpose" + std::to_string(i)))); subgraph.at(pattern.GetPDNode("transpose" + std::to_string(i))),
PADDLE_ENFORCE( platform::errors::NotFound("Can not find transpose%d in subgraph.",
subgraph.at(pattern.GetPDNode("transpose_out" + std::to_string(i)))); i));
PADDLE_ENFORCE( PADDLE_ENFORCE_NOT_NULL(
subgraph.at(pattern.GetPDNode("flatten" + std::to_string(i)))); subgraph.at(pattern.GetPDNode("transpose_out" + std::to_string(i))),
PADDLE_ENFORCE( platform::errors::NotFound(
subgraph.at(pattern.GetPDNode("flatten_out" + std::to_string(i)))); "Can not find transpose_out%d in subgraph.", i));
PADDLE_ENFORCE(subgraph.at(input_nodes[i])); PADDLE_ENFORCE_NOT_NULL(
subgraph.at(pattern.GetPDNode("flatten" + std::to_string(i))),
platform::errors::NotFound("Can not find flatten%d in subgraph.", i));
PADDLE_ENFORCE_NOT_NULL(
subgraph.at(pattern.GetPDNode("flatten_out" + std::to_string(i))),
platform::errors::NotFound("Can not find flatten_out%d in subgraph.",
i));
PADDLE_ENFORCE_NOT_NULL(
subgraph.at(input_nodes[i]),
platform::errors::NotFound("Can not find %s in subgraph.",
input_nodes[i]->name()));
nodes.push_back(subgraph.at(input_nodes[i])); nodes.push_back(subgraph.at(input_nodes[i]));
nodes.push_back( nodes.push_back(
......
...@@ -37,7 +37,10 @@ inline std::string LibraryTypeToString(const LibraryType& library_type) { ...@@ -37,7 +37,10 @@ inline std::string LibraryTypeToString(const LibraryType& library_type) {
case LibraryType::kCUDNN: case LibraryType::kCUDNN:
return "CUDNN"; return "CUDNN";
default: default:
PADDLE_THROW("unknown LibraryType %d", static_cast<int>(library_type)); PADDLE_THROW(platform::errors::Unimplemented(
"Unknown LibraryType code (%d), only supports library type include "
"PLAIN(0), MKLDNN(1), CUDNN(2).",
static_cast<int>(library_type)));
} }
} }
...@@ -59,7 +62,10 @@ inline LibraryType StringToLibraryType(const char* ctype) { ...@@ -59,7 +62,10 @@ inline LibraryType StringToLibraryType(const char* ctype) {
} else if (s == std::string("CUDA")) { } else if (s == std::string("CUDA")) {
return LibraryType::kPlain; return LibraryType::kPlain;
} else { } else {
PADDLE_THROW("Unknown LibraryType %s", s.c_str()); PADDLE_THROW(platform::errors::Unimplemented(
"Unknown LibraryType string (%s), only support library type string "
"include PLAIN, MKLDNN, CUDNN, CPU and CUDA.",
s.c_str()));
} }
} }
......
...@@ -35,7 +35,10 @@ T *DynLoad(void *handle, std::string name) { ...@@ -35,7 +35,10 @@ T *DynLoad(void *handle, std::string name) {
#else #else
auto errorno = GetLastError(); auto errorno = GetLastError();
#endif // !_WIN32 #endif // !_WIN32
PADDLE_ENFORCE_NOT_NULL(func, errorno); PADDLE_ENFORCE_NOT_NULL(
func,
platform::errors::NotFound(
"Failed to load dynamic operator library, error code(%s).", errorno));
return func; return func;
} }
...@@ -63,9 +66,9 @@ void LoadOpLib(const std::string &dso_name) { ...@@ -63,9 +66,9 @@ void LoadOpLib(const std::string &dso_name) {
type == "conditional_block" || type == "conditional_block_grad") { type == "conditional_block" || type == "conditional_block_grad") {
continue; continue;
} }
if (info_map.Has(n.first)) { PADDLE_ENFORCE_NE(info_map.Has(n.first), true,
PADDLE_THROW("Op %s has been registered."); platform::errors::AlreadyExists(
} "Operator (%s) has been registered.", type));
OpInfo info; OpInfo info;
info.creator_ = n.second.creator_; info.creator_ = n.second.creator_;
...@@ -88,7 +91,8 @@ void LoadOpLib(const std::string &dso_name) { ...@@ -88,7 +91,8 @@ void LoadOpLib(const std::string &dso_name) {
for (auto &str : strs) { for (auto &str : strs) {
proto::OpDesc proto_desc; proto::OpDesc proto_desc;
PADDLE_ENFORCE_EQ(proto_desc.ParseFromString(str), true, PADDLE_ENFORCE_EQ(proto_desc.ParseFromString(str), true,
"Failed to parse OpDesc from string"); platform::errors::InvalidArgument(
"Failed to parse OpDesc from string."));
ret.emplace_back(new OpDesc(proto_desc, nullptr)); ret.emplace_back(new OpDesc(proto_desc, nullptr));
} }
return ret; return ret;
......
...@@ -19,9 +19,11 @@ namespace framework { ...@@ -19,9 +19,11 @@ namespace framework {
void LoDRankTable::Reset(const LoD& lod, size_t level) { void LoDRankTable::Reset(const LoD& lod, size_t level) {
this->coarse_lod_.clear(); this->coarse_lod_.clear();
this->items_.clear(); this->items_.clear();
PADDLE_ENFORCE(level < lod.size(), PADDLE_ENFORCE_LT(
"Cannot rank lod since the level %d is less than lod size %d", level, lod.size(),
level, lod.size()); platform::errors::InvalidArgument(
"Cannot reset LoD since the level %d is less than lod size %d.",
level, lod.size()));
coarse_lod_.reserve(level); coarse_lod_.reserve(level);
for (size_t i = 0; i < level; ++i) { for (size_t i = 0; i < level; ++i) {
coarse_lod_.push_back(lod[i]); coarse_lod_.push_back(lod[i]);
......
...@@ -65,9 +65,23 @@ std::string LoDToString(const LoD &lod) { ...@@ -65,9 +65,23 @@ std::string LoDToString(const LoD &lod) {
LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin, LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin,
size_t elem_end) { size_t elem_end) {
PADDLE_ENFORCE_LT(level, in.size()); PADDLE_ENFORCE_LT(level, in.size(),
PADDLE_ENFORCE_LT(elem_begin, elem_end); platform::errors::InvalidArgument(
PADDLE_ENFORCE_LT(elem_end, in[level].size()); "The input LoDTensor's lod level should be less than "
"the LoD size, but received level is %d, LoD is %s.",
level, in));
PADDLE_ENFORCE_LT(
elem_begin, elem_end,
platform::errors::InvalidArgument(
"The index to start slicing should be less than the index to end "
"slicing, but received start index is %d, end index is %d.",
elem_begin, elem_end));
PADDLE_ENFORCE_LT(
elem_end, in[level].size(),
platform::errors::InvalidArgument(
"The index to end slicing should be less than the input LoD size, "
"but received end index is %d, LoD size is %d.",
elem_end, in[level].size()));
LoD res; LoD res;
res.resize(in.size() - level); res.resize(in.size() - level);
...@@ -185,8 +199,17 @@ LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx, ...@@ -185,8 +199,17 @@ LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx,
LoD sub_lod; LoD sub_lod;
for (size_t level_idx = start_level; level_idx < lod.size(); ++level_idx) { for (size_t level_idx = start_level; level_idx < lod.size(); ++level_idx) {
PADDLE_ENFORCE_LE(start_idx, end_idx); PADDLE_ENFORCE_LE(start_idx, end_idx,
PADDLE_ENFORCE_LT(end_idx, lod[level_idx].size()); platform::errors::InvalidArgument(
"The start index should be less than the end index, "
"but received start index is %d, end index is %d.",
start_idx, end_idx));
PADDLE_ENFORCE_LT(
end_idx, lod[level_idx].size(),
platform::errors::InvalidArgument(
"The end index should be less than the LoD level size, but "
"received end index is %d, LoD level size is %d.",
end_idx, lod[level_idx].size()));
std::vector<size_t> level_lens; std::vector<size_t> level_lens;
for (size_t i = start_idx; i < end_idx; ++i) { for (size_t i = start_idx; i < end_idx; ++i) {
level_lens.push_back(lod[level_idx][i + 1] - lod[level_idx][i]); level_lens.push_back(lod[level_idx][i + 1] - lod[level_idx][i]);
...@@ -202,7 +225,10 @@ LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx, ...@@ -202,7 +225,10 @@ LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx,
void AppendLoD(LoD *lod, const LoD &lod_length) { void AppendLoD(LoD *lod, const LoD &lod_length) {
PADDLE_ENFORCE( PADDLE_ENFORCE(
lod->empty() || lod->size() == lod_length.size(), lod->empty() || lod->size() == lod_length.size(),
"The lod_length should has the same size with the appended lod."); platform::errors::InvalidArgument(
"The input LoD length should be equal to the appended LoD size, but "
"received input LoD length is %d, actual LoD size is %d.",
lod_length, lod->size()));
if (lod->empty()) { if (lod->empty()) {
for (size_t i = 0; i < lod_length.size(); ++i) { for (size_t i = 0; i < lod_length.size(); ++i) {
lod->emplace_back(1, 0); // size = 1, value = 0; lod->emplace_back(1, 0); // size = 1, value = 0;
...@@ -254,11 +280,11 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, ...@@ -254,11 +280,11 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
is.read(reinterpret_cast<char *>(&version), sizeof(version)); is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(framework::IsTensorVersionSupported(version), true, PADDLE_ENFORCE_EQ(framework::IsTensorVersionSupported(version), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"tensor version %u is not supported.", version)); "Tensor version %u is not supported.", version));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
version, 0U, version, 0U,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"tensor version %u is not supported, Only version 0 is supported", "Tensor version %u is not supported, only version 0 is supported.",
version)); version));
} }
{ {
...@@ -280,11 +306,11 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, ...@@ -280,11 +306,11 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
is.read(reinterpret_cast<char *>(&version), sizeof(version)); is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(framework::IsTensorVersionSupported(version), true, PADDLE_ENFORCE_EQ(framework::IsTensorVersionSupported(version), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"tensor version %u is not supported.", version)); "Tensor version %u is not supported.", version));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
version, 0U, version, 0U,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"tensor version %u is not supported, Only version 0 is supported", "Tensor version %u is not supported, only version 0 is supported.",
version)); version));
} }
{ {
...@@ -310,7 +336,7 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor( ...@@ -310,7 +336,7 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
const std::vector<platform::Place> places) const { const std::vector<platform::Place> places) const {
PADDLE_ENFORCE_GT(places.size(), 0, PADDLE_ENFORCE_GT(places.size(), 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"place number cannot be empty when splitting")); "Place number cannot be empty when splitting."));
check_memory_size(); check_memory_size();
size_t batch_size = size_t batch_size =
lod().empty() ? static_cast<size_t>(dims()[0]) : lod()[0].size() - 1; lod().empty() ? static_cast<size_t>(dims()[0]) : lod()[0].size() - 1;
...@@ -342,7 +368,9 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor( ...@@ -342,7 +368,9 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
auto end = std::min<size_t>((i + 1) * step_width, batch_size); auto end = std::min<size_t>((i + 1) * step_width, batch_size);
PADDLE_ENFORCE_LT(begin, end, PADDLE_ENFORCE_LT(begin, end,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"begin must be less than end, this may be a bug")); "The begin index must be less than the end index, "
"but received begin index is %d, end index is %d.",
begin, end));
LoDTensor dst; LoDTensor dst;
if (lod().empty()) { if (lod().empty()) {
...@@ -376,7 +404,9 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor( ...@@ -376,7 +404,9 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
void LoDTensor::MergeLoDTensor( void LoDTensor::MergeLoDTensor(
const std::vector<const LoDTensor *> &lod_tensors, const std::vector<const LoDTensor *> &lod_tensors,
platform::Place dst_place) { platform::Place dst_place) {
PADDLE_ENFORCE(!lod_tensors.empty()); PADDLE_ENFORCE_EQ(lod_tensors.empty(), false,
platform::errors::InvalidArgument(
"The LoDTensors to be merged are empty."));
framework::DDim new_dim = lod_tensors[0]->dims(); framework::DDim new_dim = lod_tensors[0]->dims();
proto::VarType::Type new_type = proto::VarType::FP32; proto::VarType::Type new_type = proto::VarType::FP32;
...@@ -395,15 +425,35 @@ void LoDTensor::MergeLoDTensor( ...@@ -395,15 +425,35 @@ void LoDTensor::MergeLoDTensor(
for (size_t i = 1; i < lod_tensors.size(); ++i) { for (size_t i = 1; i < lod_tensors.size(); ++i) {
auto *t = lod_tensors[i]; auto *t = lod_tensors[i];
if (t->numel() && t->IsInitialized()) { if (t->numel() && t->IsInitialized()) {
PADDLE_ENFORCE_EQ(new_type, t->type()); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(new_layout, t->layout()); new_type, t->type(),
PADDLE_ENFORCE_EQ(framework::product(new_dim) / new_dim[0], platform::errors::InvalidArgument(
framework::product(t->dims()) / t->dims()[0]); "LoDTensor data type does not match, expected type is %s, actual "
"type is %s.",
DataTypeToString(new_type), DataTypeToString(t->type())));
PADDLE_ENFORCE_EQ(
new_layout, t->layout(),
platform::errors::InvalidArgument(
"LoDTensor layout does not match, expected layout is %s, "
"actual layout is %s.",
DataLayoutToString(new_layout), DataLayoutToString(t->layout())));
PADDLE_ENFORCE_EQ(
framework::product(new_dim) / new_dim[0],
framework::product(t->dims()) / t->dims()[0],
platform::errors::InvalidArgument(
"LoDTensor dimension does not match, all dimensions except the "
"first dimension need to be equal,"
"but expected dimension is %s, actual dimension is %s.",
new_dim, t->dims()));
new_dim[0] += t->dims()[0]; new_dim[0] += t->dims()[0];
} }
auto &lod = t->lod(); auto &lod = t->lod();
PADDLE_ENFORCE_EQ(new_lod.size(), lod.size()); PADDLE_ENFORCE_EQ(new_lod.size(), lod.size(),
platform::errors::InvalidArgument(
"The LoD information of LoDTensor does not match, "
"expected LoD is %s, actual LoD is %s.",
new_lod, lod));
for (size_t j = 0; j < lod.size(); ++j) { for (size_t j = 0; j < lod.size(); ++j) {
auto &sub_lod = new_lod[j]; auto &sub_lod = new_lod[j];
size_t offset = sub_lod.back(); size_t offset = sub_lod.back();
......
...@@ -117,8 +117,19 @@ class LoDTensor : public Tensor { ...@@ -117,8 +117,19 @@ class LoDTensor : public Tensor {
* Get the start offset and end offset of an element from LoD. * Get the start offset and end offset of an element from LoD.
*/ */
std::pair<size_t, size_t> lod_element(size_t level, size_t elem) const { std::pair<size_t, size_t> lod_element(size_t level, size_t elem) const {
PADDLE_ENFORCE_LT(level, NumLevels()); PADDLE_ENFORCE_LT(
PADDLE_ENFORCE_LT(elem, NumElements(level)); level, NumLevels(),
platform::errors::InvalidArgument(
"The input level of LoD is invalid, it should be less than LoD "
"size. The input level is %zu, the LoD size is %zu.",
level, NumLevels()));
PADDLE_ENFORCE_LT(elem, NumElements(level),
platform::errors::InvalidArgument(
"The input element of LoD is invalid, it should be "
"less than the number of elements in its level."
"The input element is %zu, the number of elements in "
"its level is %zu.",
elem, NumElements(level)));
return std::make_pair((lod_)[level][elem], (lod_)[level][elem + 1]); return std::make_pair((lod_)[level][elem], (lod_)[level][elem + 1]);
} }
...@@ -131,7 +142,12 @@ class LoDTensor : public Tensor { ...@@ -131,7 +142,12 @@ class LoDTensor : public Tensor {
* Number of elements in a level. * Number of elements in a level.
*/ */
size_t NumElements(size_t level = 0) const { size_t NumElements(size_t level = 0) const {
PADDLE_ENFORCE_LT(level, NumLevels()); PADDLE_ENFORCE_LT(
level, NumLevels(),
platform::errors::InvalidArgument(
"The input level of LoD is invalid, it should be less than LoD "
"size. The input level is %zu, the LoD size is %zu.",
level, NumLevels()));
// the last offset is the end of last element // the last offset is the end of last element
return (lod_)[level].size() - 1; return (lod_)[level].size() - 1;
} }
...@@ -172,7 +188,13 @@ LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level, ...@@ -172,7 +188,13 @@ LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level,
tensor.Resize(dims); tensor.Resize(dims);
tensor.mutable_data<T>(place); tensor.mutable_data<T>(place);
PADDLE_ENFORCE_EQ(num_instances, lod_level.size() - 1); PADDLE_ENFORCE_EQ(
num_instances, lod_level.size() - 1,
platform::errors::InvalidArgument(
"The input LoDTensor instance number should be equal to the LoD "
"level size minus 1."
"The input instance number is %zu, LoD level size is %zu.",
num_instances, lod_level.size()));
for (size_t ins = 0; ins < num_instances; ins++) { for (size_t ins = 0; ins < num_instances; ins++) {
for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1]; elem++) { for (size_t elem = lod_level[ins]; elem < lod_level[ins + 1]; elem++) {
auto slice = tensor.Slice(elem, elem + 1); auto slice = tensor.Slice(elem, elem + 1);
......
...@@ -155,8 +155,10 @@ class Vector { ...@@ -155,8 +155,10 @@ class Vector {
// get cuda ptr. immutable // get cuda ptr. immutable
const T *CUDAData(platform::Place place) const { const T *CUDAData(platform::Place place) const {
PADDLE_ENFORCE(platform::is_gpu_place(place), PADDLE_ENFORCE_EQ(
"CUDA Data must on CUDA place"); platform::is_gpu_place(place), true,
platform::errors::Unavailable(
"Place mismatch, CUDA Data must be on CUDA place."));
ImmutableCUDA(place); ImmutableCUDA(place);
return reinterpret_cast<T *>(gpu_->ptr()); return reinterpret_cast<T *>(gpu_->ptr());
} }
...@@ -234,7 +236,8 @@ class Vector { ...@@ -234,7 +236,8 @@ class Vector {
UnsetFlag(kDirty); UnsetFlag(kDirty);
SetFlag(kDataInCUDA); SetFlag(kDataInCUDA);
} else if (IsInCUDA() && !(place == gpu_->place())) { } else if (IsInCUDA() && !(place == gpu_->place())) {
PADDLE_THROW("This situation should not happen"); PADDLE_THROW(
platform::errors::Unavailable("Unexpected data place mismatch."));
// Still dirty // Still dirty
} else { } else {
// Dirty && DataInCUDA && Device is same // Dirty && DataInCUDA && Device is same
...@@ -246,7 +249,8 @@ class Vector { ...@@ -246,7 +249,8 @@ class Vector {
CopyCPUDataToCUDA(place); CopyCPUDataToCUDA(place);
SetFlag(kDataInCUDA); SetFlag(kDataInCUDA);
} else if (!(place == gpu_->place())) { } else if (!(place == gpu_->place())) {
PADDLE_THROW("This situation should not happen."); PADDLE_THROW(
platform::errors::Unavailable("Unexpected data place mismatch."));
} else { } else {
// Not Dirty && DataInCUDA && Device is same // Not Dirty && DataInCUDA && Device is same
// Do nothing. // Do nothing.
...@@ -501,27 +505,29 @@ class CPUVector : public std::vector<T, std::allocator<T>> { ...@@ -501,27 +505,29 @@ class CPUVector : public std::vector<T, std::allocator<T>> {
} }
const T *CUDAData(platform::Place place) const { const T *CUDAData(platform::Place place) const {
PADDLE_THROW( PADDLE_THROW(platform::errors::Unavailable(
"Vector::CUDAData() method is not supported in CPU-only version"); "Vector::CUDAData() method is not supported in CPU-only version."));
} }
T *CUDAMutableData(platform::Place place) { T *CUDAMutableData(platform::Place place) {
PADDLE_THROW( PADDLE_THROW(platform::errors::Unavailable(
"Vector::CUDAMutableData() method is not supported in CPU-only " "Vector::CUDAMutableData() method is not supported in CPU-only "
"version"); "version."));
} }
const T *Data(platform::Place place) const { const T *Data(platform::Place place) const {
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
platform::is_cpu_place(place), platform::is_cpu_place(place), true,
"Vector::Data() method is not supported when not in CPUPlace"); platform::errors::Unavailable(
"Vector::Data() method is not supported when not in CPUPlace."));
return this->data(); return this->data();
} }
T *MutableData(platform::Place place) { T *MutableData(platform::Place place) {
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
platform::is_cpu_place(place), platform::is_cpu_place(place), true,
"Vector::MutableData() method is not supported when not in CPUPlace"); platform::errors::Unavailable("Vector::MutableData() method is not "
"supported when not in CPUPlace."));
return this->data(); return this->data();
} }
......
...@@ -106,7 +106,7 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program, ...@@ -106,7 +106,7 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
} }
void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) { void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
if (need_dump_field_) { if (need_dump_field_ || need_dump_param_) {
InitDumpEnv(); InitDumpEnv();
} }
VLOG(3) << "init other env done."; VLOG(3) << "init other env done.";
...@@ -133,7 +133,7 @@ void MultiTrainer::Run() { ...@@ -133,7 +133,7 @@ void MultiTrainer::Run() {
} }
void MultiTrainer::Finalize() { void MultiTrainer::Finalize() {
if (need_dump_field_) { if (need_dump_field_ || need_dump_param_) {
FinalizeDumpEnv(); FinalizeDumpEnv();
} }
root_scope_->DropKids(); root_scope_->DropKids();
......
...@@ -25,6 +25,9 @@ ...@@ -25,6 +25,9 @@
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/string/pretty_log.h" #include "paddle/fluid/string/pretty_log.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -51,12 +54,16 @@ void NaiveExecutor::Run() { ...@@ -51,12 +54,16 @@ void NaiveExecutor::Run() {
void NaiveExecutor::CreateVariables(const ProgramDesc &desc, int block_id, void NaiveExecutor::CreateVariables(const ProgramDesc &desc, int block_id,
bool persistable, Scope *scope) { bool persistable, Scope *scope) {
PADDLE_ENFORCE_NOT_NULL(scope); PADDLE_ENFORCE_NOT_NULL(scope,
platform::errors::InvalidArgument(
"The Scope to hold variables is nullptr."));
auto &global_block = desc.Block(block_id); auto &global_block = desc.Block(block_id);
const auto *anc = scope; const auto *anc = scope;
PADDLE_ENFORCE(anc->parent() != anc); PADDLE_ENFORCE_NE(
anc->parent(), anc,
platform::errors::InvalidArgument("Input scope should be child scope."));
while (anc->parent()) { while (anc->parent()) {
anc = anc->parent(); anc = anc->parent();
} }
...@@ -101,9 +108,12 @@ void NaiveExecutor::CreateOps(const ProgramDesc &desc, int block_id, ...@@ -101,9 +108,12 @@ void NaiveExecutor::CreateOps(const ProgramDesc &desc, int block_id,
} }
LoDTensor *NaiveExecutor::FindTensor(const std::string &name) { LoDTensor *NaiveExecutor::FindTensor(const std::string &name) {
PADDLE_ENFORCE(scope_, "Need to init scope first"); PADDLE_ENFORCE_NOT_NULL(scope_,
platform::errors::PreconditionNotMet(
"Need to init scope in NaiveExecutor firstly."));
auto *var = scope_->FindVar(name); auto *var = scope_->FindVar(name);
PADDLE_ENFORCE(var, "No variable [%s] in the scope"); PADDLE_ENFORCE_NOT_NULL(var, platform::errors::NotFound(
"No variable [%s] in current scope.", name));
auto *tensor = const_cast<LoDTensor *>(&var->Get<LoDTensor>()); auto *tensor = const_cast<LoDTensor *>(&var->Get<LoDTensor>());
return tensor; return tensor;
} }
...@@ -122,14 +132,7 @@ NaiveExecutor::~NaiveExecutor() { ...@@ -122,14 +132,7 @@ NaiveExecutor::~NaiveExecutor() {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache, // Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working // this is needed to have mkl-dnn unit tests working
if (platform::is_cpu_place(place_)) { ClearMKLDNNCache(place_);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext *dev_ctx =
(platform::MKLDNNDeviceContext *)pool.Get(place_);
dev_ctx->ResetBlobMap();
platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
paddle::framework::DataLayout::kNCHW);
}
#endif #endif
} }
......
...@@ -23,8 +23,9 @@ namespace framework { ...@@ -23,8 +23,9 @@ namespace framework {
const Attribute &InferNoNeedBufferVarsContext::GetAttr( const Attribute &InferNoNeedBufferVarsContext::GetAttr(
const std::string &name) const { const std::string &name) const {
auto iter = attrs_.find(name); auto iter = attrs_.find(name);
PADDLE_ENFORCE_EQ(iter != attrs_.end(), true, "Cannot find attribute %s", PADDLE_ENFORCE_NE(
name); iter, attrs_.end(),
platform::errors::NotFound("Cannot find attribute (%s).", name));
return iter->second; return iter->second;
} }
......
...@@ -101,7 +101,10 @@ class InferNoNeedBufferVarsFN { ...@@ -101,7 +101,10 @@ class InferNoNeedBufferVarsFN {
inline const std::unordered_set<std::string> &operator()( inline const std::unordered_set<std::string> &operator()(
const VariableNameMap &inputs, const VariableNameMap &outputs, const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs) const { const AttributeMap &attrs) const {
PADDLE_ENFORCE_NOT_NULL(inferer_); PADDLE_ENFORCE_NOT_NULL(
inferer_,
platform::errors::PreconditionNotMet(
"The `inferer_` of InferNoNeedBufferVarsFN is not initialized."));
StaticGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs); StaticGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
return (*inferer_)(ctx); return (*inferer_)(ctx);
} }
...@@ -110,7 +113,10 @@ class InferNoNeedBufferVarsFN { ...@@ -110,7 +113,10 @@ class InferNoNeedBufferVarsFN {
const imperative::NameVarMap<imperative::VariableWrapper> &inputs, const imperative::NameVarMap<imperative::VariableWrapper> &inputs,
const imperative::NameVarMap<imperative::VariableWrapper> &outputs, const imperative::NameVarMap<imperative::VariableWrapper> &outputs,
const AttributeMap &attrs) const { const AttributeMap &attrs) const {
PADDLE_ENFORCE_NOT_NULL(inferer_); PADDLE_ENFORCE_NOT_NULL(
inferer_,
platform::errors::PreconditionNotMet(
"The `inferer_` of InferNoNeedBufferVarsFN is not initialized."));
DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs); DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
return (*inferer_)(ctx); return (*inferer_)(ctx);
} }
...@@ -120,8 +126,14 @@ class InferNoNeedBufferVarsFN { ...@@ -120,8 +126,14 @@ class InferNoNeedBufferVarsFN {
inline bool operator!() const { return inferer_ == nullptr; } inline bool operator!() const { return inferer_ == nullptr; }
inline void Reset(const std::shared_ptr<NoNeedBufferVarsInference> &inferer) { inline void Reset(const std::shared_ptr<NoNeedBufferVarsInference> &inferer) {
PADDLE_ENFORCE_NOT_NULL(inferer); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE_EQ(inferer_, nullptr); inferer, platform::errors::InvalidArgument("The input inferer of "
"InferNoNeedBufferVarsFN::"
"Reset is nullptr."));
PADDLE_ENFORCE_EQ(
inferer_, nullptr,
platform::errors::AlreadyExists(
"The `inferer_` of InferNoNeedBufferVarsFN has been initialized."));
inferer_ = inferer; inferer_ = inferer;
} }
......
...@@ -35,26 +35,14 @@ void InsertCallStackInfo(const std::string &type, const AttributeMap &attrs, ...@@ -35,26 +35,14 @@ void InsertCallStackInfo(const std::string &type, const AttributeMap &attrs,
} }
std::ostringstream sout; std::ostringstream sout;
std::ostringstream sout_py_trace;
// Step 1. Construct python call stack string // Step 1. Construct python call stack string
if (callstack) { if (callstack) {
sout_py_trace << "\n------------------------------------------\n"; sout << "\n\n Compile Traceback (most recent call last):";
sout_py_trace << "Python Call Stacks (More useful to users):";
sout_py_trace << "\n------------------------------------------\n";
for (auto &line : *callstack) { for (auto &line : *callstack) {
sout_py_trace << line; sout << "\n " << line;
} }
} }
// Step 2. Insert python traceback into err_str_ // Step 2. Construct final call stack & append error op name
std::size_t found = exception->err_str_.rfind(
"\n----------------------\nError Message "
"Summary:\n----------------------\n");
if (found != std::string::npos) {
exception->err_str_.insert(found, sout_py_trace.str());
} else {
exception->err_str_.append(sout_py_trace.str());
}
// Step 3. Construct final call stack & append error op name
sout << exception->err_str_; sout << exception->err_str_;
sout << " [operator < " << type << " > error]"; sout << " [operator < " << type << " > error]";
exception->err_str_ = sout.str(); exception->err_str_ = sout.str();
......
...@@ -24,9 +24,10 @@ namespace framework { ...@@ -24,9 +24,10 @@ namespace framework {
inline std::vector<int> ConvertStr2Int(const std::string& str_text) { inline std::vector<int> ConvertStr2Int(const std::string& str_text) {
auto vec_text = string::split_string<std::string>(str_text, "."); auto vec_text = string::split_string<std::string>(str_text, ".");
PADDLE_ENFORCE((vec_text.size() == 2 || vec_text.size() == 3), PADDLE_ENFORCE(
"Input[%s] is not a right version format [1.6 or 1.6.0]", (vec_text.size() == 2 || vec_text.size() == 3),
str_text); platform::errors::InvalidArgument(
"Input[%s] is not a right version format [1.6 or 1.6.0].", str_text));
std::vector<int> vec_res; std::vector<int> vec_res;
vec_res.reserve(3); vec_res.reserve(3);
...@@ -49,10 +50,11 @@ inline bool CompareVersion(const std::string& str_first, ...@@ -49,10 +50,11 @@ inline bool CompareVersion(const std::string& str_first,
auto vec_second_version = ConvertStr2Int(str_second); auto vec_second_version = ConvertStr2Int(str_second);
// first version id // first version id
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(vec_first_version.size(), vec_second_version.size(),
vec_first_version.size(), vec_second_version.size(), platform::errors::InvalidArgument(
"version information size not equal, first is [%d] second is [%d]", "Version information size is not equal, the first is "
vec_first_version.size(), vec_second_version.size()); "[%d], the second is [%d].",
vec_first_version.size(), vec_second_version.size()));
for (size_t i = 0; i < vec_first_version.size() - 1; ++i) { for (size_t i = 0; i < vec_first_version.size() - 1; ++i) {
if (vec_first_version[i] != vec_second_version[i]) { if (vec_first_version[i] != vec_second_version[i]) {
......
...@@ -700,7 +700,7 @@ void OpDesc::InferShape(const BlockDesc &block) const { ...@@ -700,7 +700,7 @@ void OpDesc::InferShape(const BlockDesc &block) const {
} }
infer_shape(&ctx); infer_shape(&ctx);
} catch (platform::EnforceNotMet &exception) { } catch (platform::EnforceNotMet &exception) {
framework::InsertCallStackInfo(Type(), attrs_, &exception); framework::AppendErrorOpHint(Type(), &exception);
throw std::move(exception); throw std::move(exception);
} catch (...) { } catch (...) {
std::rethrow_exception(std::current_exception()); std::rethrow_exception(std::current_exception());
......
...@@ -117,7 +117,7 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -117,7 +117,7 @@ TEST(OpRegistry, IllegalAttr) {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::platform::EnforceNotMet& err) { } catch (paddle::platform::EnforceNotMet& err) {
caught = true; caught = true;
std::string msg = "larger_than check fail"; std::string msg = "OutOfRangeError";
std::string err_msg = err.what(); std::string err_msg = err.what();
ASSERT_TRUE(err_msg.find(msg) != std::string::npos); ASSERT_TRUE(err_msg.find(msg) != std::string::npos);
} }
...@@ -151,7 +151,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -151,7 +151,7 @@ TEST(OpRegistry, CustomChecker) {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::platform::EnforceNotMet& err) { } catch (paddle::platform::EnforceNotMet& err) {
caught = true; caught = true;
std::string msg = "Attribute 'test_attr' is required!"; std::string msg = "InvalidArgumentError";
std::string err_msg = err.what(); std::string err_msg = err.what();
ASSERT_TRUE(err_msg.find(msg) != std::string::npos); ASSERT_TRUE(err_msg.find(msg) != std::string::npos);
} }
......
...@@ -155,8 +155,9 @@ class OperatorBase { ...@@ -155,8 +155,9 @@ class OperatorBase {
bool HasAttr(const std::string& name) const { return attrs_.count(name); } bool HasAttr(const std::string& name) const { return attrs_.count(name); }
template <typename T> template <typename T>
inline const T& Attr(const std::string& name) const { inline const T& Attr(const std::string& name) const {
PADDLE_ENFORCE(attrs_.find(name) != attrs_.end(), PADDLE_ENFORCE_NE(
"%s should be in AttributeMap", name); attrs_.find(name), attrs_.end(),
platform::errors::NotFound("(%s) is not found in AttributeMap.", name));
return BOOST_GET_CONST(T, attrs_.at(name)); return BOOST_GET_CONST(T, attrs_.at(name));
} }
const AttributeMap& Attrs() const { return attrs_; } const AttributeMap& Attrs() const { return attrs_; }
...@@ -165,7 +166,9 @@ class OperatorBase { ...@@ -165,7 +166,9 @@ class OperatorBase {
const VariableNameMap& Outputs() const { return outputs_; } const VariableNameMap& Outputs() const { return outputs_; }
const OpInfo& Info() const { const OpInfo& Info() const {
PADDLE_ENFORCE_NOT_NULL(info_, "OpInfo of %s is not found", type_); PADDLE_ENFORCE_NOT_NULL(
info_, platform::errors::NotFound(
"OpInfo of operator (%s) is not found.", type_));
return *info_; return *info_;
} }
...@@ -369,7 +372,9 @@ class ExecutionContext { ...@@ -369,7 +372,9 @@ class ExecutionContext {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
const inline platform::CUDADeviceContext& cuda_device_context() const { const inline platform::CUDADeviceContext& cuda_device_context() const {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(device_context_.GetPlace()), true); PADDLE_ENFORCE_EQ(platform::is_gpu_place(device_context_.GetPlace()), true,
platform::errors::PreconditionNotMet(
"Current device context place is not GPUPlace."));
return *reinterpret_cast<const platform::CUDADeviceContext*>( return *reinterpret_cast<const platform::CUDADeviceContext*>(
&device_context_); &device_context_);
} }
...@@ -384,8 +389,12 @@ class ExecutionContext { ...@@ -384,8 +389,12 @@ class ExecutionContext {
auto shared_allocation = std::shared_ptr<memory::allocation::Allocation>( auto shared_allocation = std::shared_ptr<memory::allocation::Allocation>(
allocation_ptr, deleter); allocation_ptr, deleter);
PADDLE_ENFORCE_GE(allocation_ptr->size(), PADDLE_ENFORCE_GE(
framework::product(dim) * sizeof(T)); allocation_ptr->size(), framework::product(dim) * sizeof(T),
platform::errors::PreconditionNotMet(
"The data memory size(%d) is less than the tensor needed memory "
"size(%d).",
allocation_ptr->size(), framework::product(dim) * sizeof(T)));
paddle::framework::Tensor temp_tensor( paddle::framework::Tensor temp_tensor(
framework::ToDataType(std::type_index(typeid(T)))); framework::ToDataType(std::type_index(typeid(T))));
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/init.h"
DECLARE_bool(enable_unused_var_check); DECLARE_bool(enable_unused_var_check);
...@@ -546,12 +547,13 @@ class GetLoDLevelTest : public OperatorWithKernel { ...@@ -546,12 +547,13 @@ class GetLoDLevelTest : public OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true, OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "GetLoDLevelTest");
"Input(X) should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GetLoDLevelTest");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) should not be null."); auto lod_level = ctx->GetLoDLevel("X");
PADDLE_ENFORCE_GT(ctx->GetLoDLevel("X"), 0, PADDLE_ENFORCE_GT(lod_level, 0,
"The LoD level Input(X) should be larger than 0."); paddle::platform::errors::InvalidArgument(
"The LoD level Input(X) should be larger than 0."));
} }
}; };
...@@ -561,10 +563,8 @@ class SetLoDLevelTest : public OperatorWithKernel { ...@@ -561,10 +563,8 @@ class SetLoDLevelTest : public OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true, OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "SetLoDLevelTest");
"Input(X) should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SetLoDLevelTest");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) should not be null.");
ctx->SetLoDLevel("Out", 1); ctx->SetLoDLevel("Out", 1);
} }
}; };
......
...@@ -250,6 +250,7 @@ void PipelineTrainer::Finalize() { ...@@ -250,6 +250,7 @@ void PipelineTrainer::Finalize() {
} }
} }
root_scope_->DropKids(); root_scope_->DropKids();
SectionWorker::ResetBatchId();
} }
Scope* PipelineTrainer::GetWorkerScope(int thread_id) { Scope* PipelineTrainer::GetWorkerScope(int thread_id) {
......
...@@ -122,7 +122,7 @@ class SelectedRows { ...@@ -122,7 +122,7 @@ class SelectedRows {
/* /*
* @brief Get the index of the key from id_to_index_ map. * @brief Get the index of the key from id_to_index_ map.
*/ */
inline int64_t GetIndexFromId(int64_t key) { inline int64_t GetIndexFromId(int64_t key) const {
auto iter = id_to_index_.find(key); auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) { if (iter == id_to_index_.end()) {
return -1; return -1;
......
...@@ -108,8 +108,15 @@ const DDim& Tensor::dims() const { return dims_; } ...@@ -108,8 +108,15 @@ const DDim& Tensor::dims() const { return dims_; }
int64_t Tensor::numel() const { return product(dims_); } int64_t Tensor::numel() const { return product(dims_); }
void Tensor::ResetHolder(std::shared_ptr<memory::Allocation> holder) { void Tensor::ResetHolder(std::shared_ptr<memory::Allocation> holder) {
PADDLE_ENFORCE_EQ(
offset_, 0,
platform::errors::Fatal(
"Only the offset is supported to zero when the holder is reset."));
if (holder_) { if (holder_) {
PADDLE_ENFORCE_EQ(numel() * SizeOfType(type()), holder->size()); PADDLE_ENFORCE_LE(
numel() * SizeOfType(type()) + offset_, holder->size(),
paddle::platform::errors::InvalidArgument(
"The size of Holder is not enough to store the Tensor."));
} }
holder_ = holder; holder_ = holder;
} }
......
...@@ -55,8 +55,13 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, ...@@ -55,8 +55,13 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
BOOST_GET_CONST(platform::CPUPlace, src_place), src_ptr, size); BOOST_GET_CONST(platform::CPUPlace, src_place), src_ptr, size);
} }
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src_place) && // NOLINT else if (platform::is_cuda_pinned_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) { platform::is_cpu_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr,
BOOST_GET_CONST(platform::CUDAPinnedPlace, src_place), src_ptr,
size);
} else if (platform::is_gpu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
auto src_gpu_place = BOOST_GET_CONST(platform::CUDAPlace, src_place); auto src_gpu_place = BOOST_GET_CONST(platform::CUDAPlace, src_place);
auto dst_cpu_place = BOOST_GET_CONST(platform::CPUPlace, dst_place); auto dst_cpu_place = BOOST_GET_CONST(platform::CPUPlace, dst_place);
auto ctx_place = ctx.GetPlace(); auto ctx_place = ctx.GetPlace();
...@@ -77,6 +82,28 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, ...@@ -77,6 +82,28 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream(); reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream); memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream);
} else if (platform::is_cuda_pinned_place(src_place) &&
platform::is_gpu_place(dst_place)) {
auto src_cuda_pinned_place =
BOOST_GET_CONST(platform::CUDAPinnedPlace, src_place);
auto dst_gpu_place = BOOST_GET_CONST(platform::CUDAPlace, dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx_place), true,
platform::errors::PreconditionNotMet(
"Device context place mismatch. When copying Tensor "
"data from CUDA Pinned memory to GPU memory, current "
"device context place should be GPU."));
auto ctx_gpu_place = BOOST_GET_CONST(platform::CUDAPlace, ctx_place);
PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place,
platform::errors::PreconditionNotMet(
"The target GPU device and current device context do "
"not match. The target GPU device number is %d, but "
"device context GPU number is %d.",
dst_gpu_place.device, ctx_gpu_place.device));
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(dst_gpu_place, dst_ptr, src_cuda_pinned_place, src_ptr, size,
stream);
} else if (platform::is_gpu_place(src_place) && } else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) { platform::is_gpu_place(dst_place)) {
auto src_gpu_place = BOOST_GET_CONST(platform::CUDAPlace, src_place); auto src_gpu_place = BOOST_GET_CONST(platform::CUDAPlace, src_place);
...@@ -148,8 +175,13 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, ...@@ -148,8 +175,13 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
BOOST_GET_CONST(platform::CPUPlace, src_place), src_ptr, size); BOOST_GET_CONST(platform::CPUPlace, src_place), src_ptr, size);
} }
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src_place) && // NOLINT else if (platform::is_cuda_pinned_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) { platform::is_cpu_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr,
BOOST_GET_CONST(platform::CUDAPinnedPlace, src_place), src_ptr,
size);
} else if (platform::is_gpu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
auto src_gpu_place = BOOST_GET_CONST(platform::CUDAPlace, src_place); auto src_gpu_place = BOOST_GET_CONST(platform::CUDAPlace, src_place);
auto dst_cpu_place = BOOST_GET_CONST(platform::CPUPlace, dst_place); auto dst_cpu_place = BOOST_GET_CONST(platform::CPUPlace, dst_place);
memory::Copy(dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, nullptr); memory::Copy(dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, nullptr);
......
...@@ -22,6 +22,8 @@ void TrainerBase::SetScope(Scope* root_scope) { root_scope_ = root_scope; } ...@@ -22,6 +22,8 @@ void TrainerBase::SetScope(Scope* root_scope) { root_scope_ = root_scope; }
void TrainerBase::ParseDumpConfig(const TrainerDesc& desc) { void TrainerBase::ParseDumpConfig(const TrainerDesc& desc) {
dump_fields_path_ = desc.dump_fields_path(); dump_fields_path_ = desc.dump_fields_path();
need_dump_field_ = false;
need_dump_param_ = false;
if (dump_fields_path_ == "") { if (dump_fields_path_ == "") {
VLOG(2) << "dump_fields_path_ is empty"; VLOG(2) << "dump_fields_path_ is empty";
return; return;
......
...@@ -79,5 +79,6 @@ void CopyVariable(const Variable &src_var, Variable *dst_var) { ...@@ -79,5 +79,6 @@ void CopyVariable(const Variable &src_var, Variable *dst_var) {
PADDLE_THROW("unknown var type to copy"); PADDLE_THROW("unknown var type to copy");
} }
} }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <vector>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
cc_library(imperative_flag SRCS flags.cc DEPS gflags) cc_library(imperative_flag SRCS flags.cc DEPS gflags)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform) cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform)
cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry) cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry)
......
...@@ -33,8 +33,10 @@ ...@@ -33,8 +33,10 @@
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) { void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy,
bool retain_graph) {
backward_strategy_ = strategy; backward_strategy_ = strategy;
retain_graph_ = retain_graph;
init_node_ = var->GradVarBase()->GradNode(); init_node_ = var->GradVarBase()->GradNode();
var->GradVarBase()->ClearGradNode(); var->GradVarBase()->ClearGradNode();
...@@ -226,7 +228,9 @@ void BasicEngine::Execute() { ...@@ -226,7 +228,9 @@ void BasicEngine::Execute() {
need_accu_var_list_.clear(); need_accu_var_list_.clear();
VLOG(3) << "Remove op after op " << cur_op.Type() << " runs"; VLOG(3) << "Remove op after op " << cur_op.Type() << " runs";
cur_op.ClearBackwardTrace(); if (!retain_graph_) {
cur_op.ClearBackwardTrace();
}
} }
// Step 3: Collect ready ops // Step 3: Collect ready ops
......
...@@ -30,7 +30,8 @@ class OpBase; ...@@ -30,7 +30,8 @@ class OpBase;
class BasicEngine : public Engine { class BasicEngine : public Engine {
public: public:
void Init(VarBase* var, const detail::BackwardStrategy& strategy); void Init(VarBase* var, const detail::BackwardStrategy& strategy,
bool retain_graph = false);
void Execute() override; void Execute() override;
...@@ -51,6 +52,7 @@ class BasicEngine : public Engine { ...@@ -51,6 +52,7 @@ class BasicEngine : public Engine {
accumulators_; accumulators_;
std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>> std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
need_accu_var_list_; need_accu_var_list_;
bool retain_graph_;
}; };
} // namespace imperative } // namespace imperative
......
...@@ -28,6 +28,11 @@ ...@@ -28,6 +28,11 @@
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
DECLARE_bool(use_mkldnn);
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
...@@ -192,6 +197,9 @@ void VarBase::ClearGradient() { ...@@ -192,6 +197,9 @@ void VarBase::ClearGradient() {
auto* grad_t = auto* grad_t =
grad_var_->MutableVar()->GetMutable<framework::SelectedRows>(); grad_var_->MutableVar()->GetMutable<framework::SelectedRows>();
if (grad_t->mutable_value()->IsInitialized()) { if (grad_t->mutable_value()->IsInitialized()) {
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) ClearMKLDNNCache(grad_t->place());
#endif
grad_t->mutable_rows()->clear(); grad_t->mutable_rows()->clear();
grad_t->mutable_value()->clear(); grad_t->mutable_value()->clear();
} }
...@@ -202,6 +210,9 @@ void VarBase::ClearGradient() { ...@@ -202,6 +210,9 @@ void VarBase::ClearGradient() {
auto* dev_ctx = auto* dev_ctx =
platform::DeviceContextPool::Instance().Get(grad_t->place()); platform::DeviceContextPool::Instance().Get(grad_t->place());
operators::math::set_constant(*dev_ctx, grad_t, 0.0); operators::math::set_constant(*dev_ctx, grad_t, 0.0);
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) ClearMKLDNNCache(grad_t->place());
#endif
} }
} }
} }
......
...@@ -36,6 +36,15 @@ ...@@ -36,6 +36,15 @@
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
struct HashPair {
template <class T1, class T2>
size_t operator()(const std::pair<T1, T2> &p) const noexcept {
auto hash1 = std::hash<T1>{}(p.first);
auto hash2 = std::hash<T2>{}(p.second);
return hash1 ^ hash2;
}
};
/** /**
* This function prunes the graph to get the ops between `output_targets` * This function prunes the graph to get the ops between `output_targets`
* and `input_target_grads`. * and `input_target_grads`.
...@@ -152,8 +161,10 @@ static void GetGraphInfoBetweenTargets( ...@@ -152,8 +161,10 @@ static void GetGraphInfoBetweenTargets(
target_vars = *input_target_grads; target_vars = *input_target_grads;
std::queue<std::pair<OpBase * /*op*/, OpBase * /*pending op*/>> op_queue; std::queue<std::pair<OpBase * /*op*/, OpBase * /*pending op*/>> op_queue;
std::unordered_set<std::pair<OpBase *, OpBase *>, HashPair> op_base_visited;
for (auto &endpoint_op : endpoint_ops) { for (auto &endpoint_op : endpoint_ops) {
op_queue.emplace(endpoint_op, nullptr); op_queue.emplace(endpoint_op, nullptr);
op_base_visited.emplace(endpoint_op, nullptr);
} }
while (!op_queue.empty()) { while (!op_queue.empty()) {
...@@ -207,6 +218,7 @@ static void GetGraphInfoBetweenTargets( ...@@ -207,6 +218,7 @@ static void GetGraphInfoBetweenTargets(
if (pending_op) { if (pending_op) {
VLOG(10) << "Pending op of " << op->Type() << " is " VLOG(10) << "Pending op of " << op->Type() << " is "
<< pending_op->Type(); << pending_op->Type();
pending_ops[op].insert(pending_op); pending_ops[op].insert(pending_op);
++op_deps[pending_op]; ++op_deps[pending_op];
} else { } else {
...@@ -216,7 +228,10 @@ static void GetGraphInfoBetweenTargets( ...@@ -216,7 +228,10 @@ static void GetGraphInfoBetweenTargets(
auto iter = preceding_ops.find(op); auto iter = preceding_ops.find(op);
if (iter != preceding_ops.end()) { if (iter != preceding_ops.end()) {
for (auto &preceding_op : iter->second) { for (auto &preceding_op : iter->second) {
op_queue.emplace(preceding_op, op); if (op_base_visited.count(std::make_pair(preceding_op, op)) == 0) {
op_queue.emplace(preceding_op, op);
op_base_visited.emplace(preceding_op, op);
}
} }
} }
} }
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册