未验证 提交 28172bbb 编写于 作者: Y Yan Chunwei 提交者: GitHub

add debug to replacing enforce with GLOG for debug (#11244)

上级 06da7402
......@@ -61,6 +61,7 @@ option(EIGEN_USE_THREADS "Compile with multi-threaded Eigen" OFF)
option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF)
option(WITH_FAST_BUNDLE_TEST "Bundle tests that can be run in a single process together to reduce launch overhead" OFF)
option(WITH_CONTRIB "Compile the third-party contributation" OFF)
option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF)
option(WITH_ANAKIN "Compile with Anakin library" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
......@@ -131,6 +132,10 @@ if (NOT DEFINED WITH_MKLDNN)
set(WITH_MKLDNN OFF)
endif()
endif()
if (REPLACE_ENFORCE_GLOG)
add_definitions("-DREPLACE_ENFORCE_GLOG")
endif()
########################################################################################
include(external/mklml) # download mklml package
......
......@@ -23,9 +23,9 @@ namespace framework {
template <typename T>
inline const T* Tensor::data() const {
check_memory_size();
PADDLE_ENFORCE(std::is_same<T, void>::value ||
holder_->type() == std::type_index(typeid(T)),
"Tensor holds the wrong type, it holds %s",
bool valid = std::is_same<T, void>::value ||
holder_->type() == std::type_index(typeid(T));
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s",
this->holder_->type().name());
return reinterpret_cast<const T*>(
......@@ -37,9 +37,9 @@ inline bool Tensor::IsInitialized() const { return holder_ != nullptr; }
template <typename T>
inline T* Tensor::data() {
check_memory_size();
PADDLE_ENFORCE(std::is_same<T, void>::value ||
holder_->type() == std::type_index(typeid(T)),
"Tensor holds the wrong type, it holds %s",
bool valid = std::is_same<T, void>::value ||
holder_->type() == std::type_index(typeid(T));
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s",
this->holder_->type().name());
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
......
......@@ -113,7 +113,11 @@ template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
bool stat, const Args&... args) {
if (UNLIKELY(!(stat))) {
#ifndef REPLACE_ENFORCE_GLOG
throw std::runtime_error(string::Sprintf(args...));
#else
LOG(FATAL) << string::Sprintf(args...);
#endif
}
}
......@@ -123,8 +127,12 @@ template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
cudaError_t e, const Args&... args) {
if (UNLIKELY(e)) {
#ifndef REPLACE_ENFORCE_GLOG
throw thrust::system_error(e, thrust::cuda_category(),
string::Sprintf(args...));
#else
LOG(FATAL) << string::Sprintf(args...);
#endif
}
}
......@@ -132,8 +140,12 @@ template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
curandStatus_t stat, const Args&... args) {
if (stat != CURAND_STATUS_SUCCESS) {
#ifndef REPLACE_ENFORCE_GLOG
throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
string::Sprintf(args...));
#else
LOG(FATAL) << string::Sprintf(args...);
#endif
}
}
......@@ -143,8 +155,12 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
if (stat == CUDNN_STATUS_SUCCESS) {
return;
} else {
#ifndef REPLACE_ENFORCE_GLOG
throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) +
string::Sprintf(args...));
#else
LOG(FATAL) << string::Sprintf(args...);
#endif
}
}
......@@ -173,7 +189,11 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
} else if (stat == CUBLAS_STATUS_LICENSE_ERROR) {
err = "CUBLAS: license error, ";
}
#ifndef REPLACE_ENFORCE_GLOG
throw std::runtime_error(err + string::Sprintf(args...));
#else
LOG(FATAL) << err << string::Sprintf(args...);
#endif
}
#ifndef __APPLE__
......@@ -183,8 +203,13 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
if (stat == ncclSuccess) {
return;
} else {
#ifndef REPLACE_ENFORCE_GLOG
throw std::runtime_error(platform::dynload::ncclGetErrorString(stat) +
string::Sprintf(args...));
#else
LOG(FATAL) << platform::dynload::ncclGetErrorString(stat)
<< string::Sprintf(args...);
#endif
}
}
#endif // __APPLE__
......@@ -203,6 +228,7 @@ inline void throw_on_error(T e) {
__FILE__, __LINE__); \
} while (false)
#ifndef REPLACE_ENFORCE_GLOG
#define PADDLE_ENFORCE(...) \
do { \
try { \
......@@ -212,6 +238,9 @@ inline void throw_on_error(T e) {
__FILE__, __LINE__); \
} \
} while (false)
#else
#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__);
#endif
/*
* Some enforce helpers here, usage:
......
......@@ -84,7 +84,7 @@ void Fprintf(std::ostream& out, const char* fmt, const Args&... args) {
}
template <typename... Args>
std::string Sprintf(const char* fmt, const Args&... args) {
std::string Sprintf(const char* fmt = "", const Args&... args) {
std::ostringstream oss;
Fprintf(oss, fmt, args...);
return oss.str();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册