未验证 提交 3358455c 编写于 作者: C Chen Weihang 提交者: GitHub

Polish and arrange code in enforce.h (#20901)

上级 728099d0
...@@ -52,11 +52,30 @@ limitations under the License. */ ...@@ -52,11 +52,30 @@ limitations under the License. */
#endif // __APPLE__ #endif // __APPLE__
#endif // PADDLE_WITH_CUDA #endif // PADDLE_WITH_CUDA
#define WITH_SIMPLE_TRACEBACK
namespace paddle { namespace paddle {
namespace platform { namespace platform {
/** HELPER MACROS AND FUNCTIONS **/
// Because most enforce conditions would evaluate to true, we can use
// __builtin_expect to instruct the C++ compiler to generate code that
// always forces branch prediction of true.
// This generates faster binary code. __builtin_expect is since C++11.
// For more details, please check https://stackoverflow.com/a/43870188/724872.
#if !defined(_WIN32)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else
// there is no equivalent intrinsics in msvc.
#define UNLIKELY(condition) (condition)
#endif
#if !defined(_WIN32)
#define LIKELY(condition) __builtin_expect(static_cast<bool>(condition), 1)
#else
// there is no equivalent intrinsics in msvc.
#define LIKELY(condition) (condition)
#endif
#ifdef __GNUC__ #ifdef __GNUC__
inline std::string demangle(std::string name) { inline std::string demangle(std::string name) {
int status = -4; // some arbitrary value to eliminate the compiler warning int status = -4; // some arbitrary value to eliminate the compiler warning
...@@ -68,6 +87,82 @@ inline std::string demangle(std::string name) { ...@@ -68,6 +87,82 @@ inline std::string demangle(std::string name) {
inline std::string demangle(std::string name) { return name; } inline std::string demangle(std::string name) { return name; }
#endif #endif
namespace details {
template <typename T>
inline constexpr bool IsArithmetic() {
return std::is_arithmetic<T>::value;
}
template <typename T1, typename T2, bool kIsArithmetic /* = true */>
struct TypeConverterImpl {
using Type1 = typename std::common_type<T1, T2>::type;
using Type2 = Type1;
};
template <typename T1, typename T2>
struct TypeConverterImpl<T1, T2, false> {
using Type1 = T1;
using Type2 = T2;
};
template <typename T1, typename T2>
struct TypeConverter {
private:
static constexpr bool kIsArithmetic =
IsArithmetic<T1>() && IsArithmetic<T2>();
public:
using Type1 = typename TypeConverterImpl<T1, T2, kIsArithmetic>::Type1;
using Type2 = typename TypeConverterImpl<T1, T2, kIsArithmetic>::Type2;
};
template <typename T1, typename T2>
using CommonType1 = typename std::add_lvalue_reference<
typename std::add_const<typename TypeConverter<T1, T2>::Type1>::type>::type;
template <typename T1, typename T2>
using CommonType2 = typename std::add_lvalue_reference<
typename std::add_const<typename TypeConverter<T1, T2>::Type2>::type>::type;
// Here, we use SFINAE to check whether T can be converted to std::string
template <typename T>
struct CanToString {
private:
using YesType = uint8_t;
using NoType = uint16_t;
template <typename U>
static YesType Check(decltype(std::cout << std::declval<U>())) {
return 0;
}
template <typename U>
static NoType Check(...) {
return 0;
}
public:
static constexpr bool kValue =
std::is_same<YesType, decltype(Check<T>(std::cout))>::value;
};
template <bool kCanToString /* = true */>
struct BinaryCompareMessageConverter {
template <typename T>
static std::string Convert(const char* expression, const T& value) {
return expression + std::string(":") + string::to_string(value);
}
};
template <>
struct BinaryCompareMessageConverter<false> {
template <typename T>
static const char* Convert(const char* expression, const T& value) {
return expression;
}
};
} // namespace details
template <typename StrType> template <typename StrType>
inline std::string GetTraceBackString(StrType&& what, const char* file, inline std::string GetTraceBackString(StrType&& what, const char* file,
int line) { int line) {
...@@ -86,21 +181,11 @@ inline std::string GetTraceBackString(StrType&& what, const char* file, ...@@ -86,21 +181,11 @@ inline std::string GetTraceBackString(StrType&& what, const char* file,
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
if (dladdr(call_stack[i], &info) && info.dli_sname) { if (dladdr(call_stack[i], &info) && info.dli_sname) {
auto demangled = demangle(info.dli_sname); auto demangled = demangle(info.dli_sname);
#ifdef WITH_SIMPLE_TRACEBACK
std::string path(info.dli_fname); std::string path(info.dli_fname);
// C++ traceback info are from core.so // C++ traceback info are from core.so
if (path.substr(path.length() - 3).compare(".so") == 0) { if (path.substr(path.length() - 3).compare(".so") == 0) {
sout << string::Sprintf("%-3d %s\n", idx++, demangled); sout << string::Sprintf("%-3d %s\n", idx++, demangled);
} }
#else
auto addr_offset = static_cast<char*>(call_stack[i]) -
static_cast<char*>(info.dli_saddr);
sout << string::Sprintf("%-3d %*0p %s + %zd\n", i, 2 + sizeof(void*) * 2,
call_stack[i], demangled, addr_offset);
} else {
sout << string::Sprintf("%-3d %*0p\n", i, 2 + sizeof(void*) * 2,
call_stack[i]);
#endif
} }
} }
free(symbols); free(symbols);
...@@ -115,8 +200,19 @@ inline std::string GetTraceBackString(StrType&& what, const char* file, ...@@ -115,8 +200,19 @@ inline std::string GetTraceBackString(StrType&& what, const char* file,
return sout.str(); return sout.str();
} }
inline bool is_error(bool stat) { return !stat; }
inline void throw_on_error(bool stat, const std::string& msg) {
#ifndef REPLACE_ENFORCE_GLOG
throw std::runtime_error(msg);
#else
LOG(FATAL) << msg;
#endif
}
/** ENFORCE EXCEPTION AND MACROS **/
struct EnforceNotMet : public std::exception { struct EnforceNotMet : public std::exception {
std::string err_str_;
EnforceNotMet(std::exception_ptr e, const char* file, int line) { EnforceNotMet(std::exception_ptr e, const char* file, int line) {
try { try {
std::rethrow_exception(e); std::rethrow_exception(e);
...@@ -124,13 +220,111 @@ struct EnforceNotMet : public std::exception { ...@@ -124,13 +220,111 @@ struct EnforceNotMet : public std::exception {
err_str_ = GetTraceBackString(e.what(), file, line); err_str_ = GetTraceBackString(e.what(), file, line);
} }
} }
EnforceNotMet(const std::string& str, const char* file, int line) EnforceNotMet(const std::string& str, const char* file, int line)
: err_str_(GetTraceBackString(str, file, line)) {} : err_str_(GetTraceBackString(str, file, line)) {}
const char* what() const noexcept override { return err_str_.c_str(); } const char* what() const noexcept override { return err_str_.c_str(); }
std::string err_str_;
}; };
#define PADDLE_THROW(...) \
do { \
throw ::paddle::platform::EnforceNotMet( \
::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
} while (0)
#if defined(__CUDA_ARCH__)
// For cuda, the assertions can affect performance and it is therefore
// recommended to disable them in production code
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
#define PADDLE_ENFORCE(_IS_NOT_ERROR, __FORMAT, ...) \
do { \
if (!(_IS_NOT_ERROR)) { \
printf("Exception: %s:%d Assertion `%s` failed. " __FORMAT "\n", \
__FILE__, __LINE__, #_IS_NOT_ERROR, ##__VA_ARGS__); \
asm("trap;"); \
} \
} while (0)
#else
#define PADDLE_ENFORCE(COND, ...) \
do { \
auto __cond__ = (COND); \
if (UNLIKELY(::paddle::platform::is_error(__cond__))) { \
try { \
::paddle::platform::throw_on_error( \
__cond__, ::paddle::string::Sprintf(__VA_ARGS__)); \
} catch (...) { \
throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
__FILE__, __LINE__); \
} \
} \
} while (0)
#endif
/*
* Some enforce helpers here, usage:
* int a = 1;
* int b = 2;
* PADDLE_ENFORCE_EQ(a, b);
*
* will raise an expression described as follows:
* "Expected input a == b, but received a(1) != b(2)."
* with detailed stack information.
*
* extra messages is also supported, for example:
* PADDLE_ENFORCE(a, b, "some simple enforce failed between %d numbers", 2)
*/
#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \
do { \
if (UNLIKELY(nullptr == (__VAL))) { \
PADDLE_THROW(#__VAL " should not be null\n%s", \
::paddle::string::Sprintf(__VA_ARGS__)); \
} \
} while (0)
#define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...) \
do { \
auto __val1 = (__VAL1); \
auto __val2 = (__VAL2); \
using __TYPE1__ = decltype(__val1); \
using __TYPE2__ = decltype(__val2); \
using __COMMON_TYPE1__ = \
::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \
using __COMMON_TYPE2__ = \
::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \
bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \
static_cast<__COMMON_TYPE2__>(__val2)); \
if (UNLIKELY(!__is_not_error)) { \
constexpr bool __kCanToString__ = \
::paddle::platform::details::CanToString<__TYPE1__>::kValue && \
::paddle::platform::details::CanToString<__TYPE2__>::kValue; \
PADDLE_THROW("Expected %s " #__CMP " %s, but received %s " #__INV_CMP \
" %s.\n%s", \
#__VAL1, #__VAL2, \
::paddle::platform::details::BinaryCompareMessageConverter< \
__kCanToString__>::Convert(#__VAL1, __val1), \
::paddle::platform::details::BinaryCompareMessageConverter< \
__kCanToString__>::Convert(#__VAL2, __val2), \
::paddle::string::Sprintf(__VA_ARGS__)); \
} \
} while (0)
#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, ==, !=, __VA_ARGS__)
#define PADDLE_ENFORCE_NE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, !=, ==, __VA_ARGS__)
#define PADDLE_ENFORCE_GT(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >, <=, __VA_ARGS__)
#define PADDLE_ENFORCE_GE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >=, <, __VA_ARGS__)
#define PADDLE_ENFORCE_LT(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__)
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
/** OTHER EXCEPTION AND ENFORCE **/
struct EOFException : public std::exception { struct EOFException : public std::exception {
std::string err_str_; std::string err_str_;
EOFException(const char* err_msg, const char* file, int line) { EOFException(const char* err_msg, const char* file, int line) {
...@@ -140,34 +334,19 @@ struct EOFException : public std::exception { ...@@ -140,34 +334,19 @@ struct EOFException : public std::exception {
const char* what() const noexcept override { return err_str_.c_str(); } const char* what() const noexcept override { return err_str_.c_str(); }
}; };
// Because most enforce conditions would evaluate to true, we can use #define PADDLE_THROW_EOF() \
// __builtin_expect to instruct the C++ compiler to generate code that do { \
// always forces branch prediction of true. throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \
// This generates faster binary code. __builtin_expect is since C++11. __LINE__); \
// For more details, please check https://stackoverflow.com/a/43870188/724872. } while (0)
#if !defined(_WIN32)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else
// there is no equivalent intrinsics in msvc.
#define UNLIKELY(condition) (condition)
#endif
#if !defined(_WIN32)
#define LIKELY(condition) __builtin_expect(static_cast<bool>(condition), 1)
#else
// there is no equivalent intrinsics in msvc.
#define LIKELY(condition) (condition)
#endif
inline bool is_error(bool stat) { return !stat; } #define PADDLE_THROW_BAD_ALLOC(...) \
do { \
throw ::paddle::memory::allocation::BadAlloc( \
::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
} while (0)
inline void throw_on_error(bool stat, const std::string& msg) { /** CUDA PADDLE ENFORCE FUNCTIONS AND MACROS **/
#ifndef REPLACE_ENFORCE_GLOG
throw std::runtime_error(msg);
#else
LOG(FATAL) << msg;
#endif
}
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -251,6 +430,7 @@ inline void throw_on_error(ncclResult_t stat, const std::string& msg) { ...@@ -251,6 +430,7 @@ inline void throw_on_error(ncclResult_t stat, const std::string& msg) {
#endif #endif
} }
#endif // __APPLE__ and windows #endif // __APPLE__ and windows
#endif // PADDLE_WITH_CUDA #endif // PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -276,41 +456,7 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess); ...@@ -276,41 +456,7 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess);
#endif #endif
} // namespace details } // namespace details
#endif #endif // PADDLE_WITH_CUDA
#define PADDLE_THROW(...) \
do { \
throw ::paddle::platform::EnforceNotMet( \
::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
} while (0)
#if defined(__CUDA_ARCH__)
// For cuda, the assertions can affect performance and it is therefore
// recommended to disable them in production code
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
#define PADDLE_ENFORCE(_IS_NOT_ERROR, __FORMAT, ...) \
do { \
if (!(_IS_NOT_ERROR)) { \
printf("Exception: %s:%d Assertion `%s` failed. " __FORMAT "\n", \
__FILE__, __LINE__, #_IS_NOT_ERROR, ##__VA_ARGS__); \
asm("trap;"); \
} \
} while (0)
#else
#define PADDLE_ENFORCE(COND, ...) \
do { \
auto __cond__ = (COND); \
if (UNLIKELY(::paddle::platform::is_error(__cond__))) { \
try { \
::paddle::platform::throw_on_error( \
__cond__, ::paddle::string::Sprintf(__VA_ARGS__)); \
} catch (...) { \
throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
__FILE__, __LINE__); \
} \
} \
} while (0)
#endif
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#define PADDLE_ENFORCE_CUDA_SUCCESS(COND, ...) \ #define PADDLE_ENFORCE_CUDA_SUCCESS(COND, ...) \
...@@ -332,157 +478,7 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess); ...@@ -332,157 +478,7 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess);
} while (0) } while (0)
#undef DEFINE_CUDA_STATUS_TYPE #undef DEFINE_CUDA_STATUS_TYPE
#endif #endif // PADDLE_WITH_CUDA
#define PADDLE_THROW_EOF() \
do { \
throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \
__LINE__); \
} while (0)
#define PADDLE_THROW_BAD_ALLOC(...) \
do { \
throw ::paddle::memory::allocation::BadAlloc( \
::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
} while (0)
/*
* Some enforce helpers here, usage:
* int a = 1;
* int b = 2;
* PADDLE_ENFORCE_EQ(a, b);
*
* will raise an expression described as follows:
* "Expected input a == b, but received a(1) != b(2)."
* with detailed stack information.
*
* extra messages is also supported, for example:
* PADDLE_ENFORCE(a, b, "some simple enforce failed between %d numbers", 2)
*/
#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \
do { \
if (UNLIKELY(nullptr == (__VAL))) { \
PADDLE_THROW(#__VAL " should not be null\n%s", \
::paddle::string::Sprintf(__VA_ARGS__)); \
} \
} while (0)
namespace details {
template <typename T>
inline constexpr bool IsArithmetic() {
return std::is_arithmetic<T>::value;
}
template <typename T1, typename T2, bool kIsArithmetic /* = true */>
struct TypeConverterImpl {
using Type1 = typename std::common_type<T1, T2>::type;
using Type2 = Type1;
};
template <typename T1, typename T2>
struct TypeConverterImpl<T1, T2, false> {
using Type1 = T1;
using Type2 = T2;
};
template <typename T1, typename T2>
struct TypeConverter {
private:
static constexpr bool kIsArithmetic =
IsArithmetic<T1>() && IsArithmetic<T2>();
public:
using Type1 = typename TypeConverterImpl<T1, T2, kIsArithmetic>::Type1;
using Type2 = typename TypeConverterImpl<T1, T2, kIsArithmetic>::Type2;
};
template <typename T1, typename T2>
using CommonType1 = typename std::add_lvalue_reference<
typename std::add_const<typename TypeConverter<T1, T2>::Type1>::type>::type;
template <typename T1, typename T2>
using CommonType2 = typename std::add_lvalue_reference<
typename std::add_const<typename TypeConverter<T1, T2>::Type2>::type>::type;
// Here, we use SFINAE to check whether T can be converted to std::string
template <typename T>
struct CanToString {
private:
using YesType = uint8_t;
using NoType = uint16_t;
template <typename U>
static YesType Check(decltype(std::cout << std::declval<U>())) {
return 0;
}
template <typename U>
static NoType Check(...) {
return 0;
}
public:
static constexpr bool kValue =
std::is_same<YesType, decltype(Check<T>(std::cout))>::value;
};
template <bool kCanToString /* = true */>
struct BinaryCompareMessageConverter {
template <typename T>
static std::string Convert(const char* expression, const T& value) {
return expression + std::string(":") + string::to_string(value);
}
};
template <>
struct BinaryCompareMessageConverter<false> {
template <typename T>
static const char* Convert(const char* expression, const T& value) {
return expression;
}
};
} // namespace details
#define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...) \
do { \
auto __val1 = (__VAL1); \
auto __val2 = (__VAL2); \
using __TYPE1__ = decltype(__val1); \
using __TYPE2__ = decltype(__val2); \
using __COMMON_TYPE1__ = \
::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \
using __COMMON_TYPE2__ = \
::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \
bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \
static_cast<__COMMON_TYPE2__>(__val2)); \
if (UNLIKELY(!__is_not_error)) { \
constexpr bool __kCanToString__ = \
::paddle::platform::details::CanToString<__TYPE1__>::kValue && \
::paddle::platform::details::CanToString<__TYPE2__>::kValue; \
PADDLE_THROW("Expected %s " #__CMP " %s, but received %s " #__INV_CMP \
" %s.\n%s", \
#__VAL1, #__VAL2, \
::paddle::platform::details::BinaryCompareMessageConverter< \
__kCanToString__>::Convert(#__VAL1, __val1), \
::paddle::platform::details::BinaryCompareMessageConverter< \
__kCanToString__>::Convert(#__VAL2, __val2), \
::paddle::string::Sprintf(__VA_ARGS__)); \
} \
} while (0)
#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, ==, !=, __VA_ARGS__)
#define PADDLE_ENFORCE_NE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, !=, ==, __VA_ARGS__)
#define PADDLE_ENFORCE_GT(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >, <=, __VA_ARGS__)
#define PADDLE_ENFORCE_GE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >=, <, __VA_ARGS__)
#define PADDLE_ENFORCE_LT(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__)
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册