From 3358455c86b2f1a0ff72892ea361f7bfe43fda7e Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 31 Oct 2019 10:35:53 +0800 Subject: [PATCH] Polish and arrange code in enforce.h (#20901) --- paddle/fluid/platform/enforce.h | 448 ++++++++++++++++---------------- 1 file changed, 222 insertions(+), 226 deletions(-) diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 9263ab401bf..5ede7220ba6 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -52,11 +52,30 @@ limitations under the License. */ #endif // __APPLE__ #endif // PADDLE_WITH_CUDA -#define WITH_SIMPLE_TRACEBACK - namespace paddle { namespace platform { +/** HELPER MACROS AND FUNCTIONS **/ + +// Because most enforce conditions would evaluate to true, we can use +// __builtin_expect to instruct the C++ compiler to generate code that +// always forces branch prediction of true. +// This generates faster binary code. __builtin_expect is since C++11. +// For more details, please check https://stackoverflow.com/a/43870188/724872. +#if !defined(_WIN32) +#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) +#else +// there is no equivalent intrinsics in msvc. +#define UNLIKELY(condition) (condition) +#endif + +#if !defined(_WIN32) +#define LIKELY(condition) __builtin_expect(static_cast(condition), 1) +#else +// there is no equivalent intrinsics in msvc. +#define LIKELY(condition) (condition) +#endif + #ifdef __GNUC__ inline std::string demangle(std::string name) { int status = -4; // some arbitrary value to eliminate the compiler warning @@ -68,6 +87,82 @@ inline std::string demangle(std::string name) { inline std::string demangle(std::string name) { return name; } #endif +namespace details { +template +inline constexpr bool IsArithmetic() { + return std::is_arithmetic::value; +} + +template +struct TypeConverterImpl { + using Type1 = typename std::common_type::type; + using Type2 = Type1; +}; + +template +struct TypeConverterImpl { + using Type1 = T1; + using Type2 = T2; +}; + +template +struct TypeConverter { + private: + static constexpr bool kIsArithmetic = + IsArithmetic() && IsArithmetic(); + + public: + using Type1 = typename TypeConverterImpl::Type1; + using Type2 = typename TypeConverterImpl::Type2; +}; + +template +using CommonType1 = typename std::add_lvalue_reference< + typename std::add_const::Type1>::type>::type; + +template +using CommonType2 = typename std::add_lvalue_reference< + typename std::add_const::Type2>::type>::type; + +// Here, we use SFINAE to check whether T can be converted to std::string +template +struct CanToString { + private: + using YesType = uint8_t; + using NoType = uint16_t; + + template + static YesType Check(decltype(std::cout << std::declval())) { + return 0; + } + + template + static NoType Check(...) { + return 0; + } + + public: + static constexpr bool kValue = + std::is_same(std::cout))>::value; +}; + +template +struct BinaryCompareMessageConverter { + template + static std::string Convert(const char* expression, const T& value) { + return expression + std::string(":") + string::to_string(value); + } +}; + +template <> +struct BinaryCompareMessageConverter { + template + static const char* Convert(const char* expression, const T& value) { + return expression; + } +}; +} // namespace details + template inline std::string GetTraceBackString(StrType&& what, const char* file, int line) { @@ -86,21 +181,11 @@ inline std::string GetTraceBackString(StrType&& what, const char* file, for (int i = 0; i < size; ++i) { if (dladdr(call_stack[i], &info) && info.dli_sname) { auto demangled = demangle(info.dli_sname); -#ifdef WITH_SIMPLE_TRACEBACK std::string path(info.dli_fname); // C++ traceback info are from core.so if (path.substr(path.length() - 3).compare(".so") == 0) { sout << string::Sprintf("%-3d %s\n", idx++, demangled); } -#else - auto addr_offset = static_cast(call_stack[i]) - - static_cast(info.dli_saddr); - sout << string::Sprintf("%-3d %*0p %s + %zd\n", i, 2 + sizeof(void*) * 2, - call_stack[i], demangled, addr_offset); - } else { - sout << string::Sprintf("%-3d %*0p\n", i, 2 + sizeof(void*) * 2, - call_stack[i]); -#endif } } free(symbols); @@ -115,8 +200,19 @@ inline std::string GetTraceBackString(StrType&& what, const char* file, return sout.str(); } +inline bool is_error(bool stat) { return !stat; } + +inline void throw_on_error(bool stat, const std::string& msg) { +#ifndef REPLACE_ENFORCE_GLOG + throw std::runtime_error(msg); +#else + LOG(FATAL) << msg; +#endif +} + +/** ENFORCE EXCEPTION AND MACROS **/ + struct EnforceNotMet : public std::exception { - std::string err_str_; EnforceNotMet(std::exception_ptr e, const char* file, int line) { try { std::rethrow_exception(e); @@ -124,13 +220,111 @@ struct EnforceNotMet : public std::exception { err_str_ = GetTraceBackString(e.what(), file, line); } } - EnforceNotMet(const std::string& str, const char* file, int line) : err_str_(GetTraceBackString(str, file, line)) {} const char* what() const noexcept override { return err_str_.c_str(); } + + std::string err_str_; }; +#define PADDLE_THROW(...) \ + do { \ + throw ::paddle::platform::EnforceNotMet( \ + ::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \ + } while (0) + +#if defined(__CUDA_ARCH__) +// For cuda, the assertions can affect performance and it is therefore +// recommended to disable them in production code +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion +#define PADDLE_ENFORCE(_IS_NOT_ERROR, __FORMAT, ...) \ + do { \ + if (!(_IS_NOT_ERROR)) { \ + printf("Exception: %s:%d Assertion `%s` failed. " __FORMAT "\n", \ + __FILE__, __LINE__, #_IS_NOT_ERROR, ##__VA_ARGS__); \ + asm("trap;"); \ + } \ + } while (0) +#else +#define PADDLE_ENFORCE(COND, ...) \ + do { \ + auto __cond__ = (COND); \ + if (UNLIKELY(::paddle::platform::is_error(__cond__))) { \ + try { \ + ::paddle::platform::throw_on_error( \ + __cond__, ::paddle::string::Sprintf(__VA_ARGS__)); \ + } catch (...) { \ + throw ::paddle::platform::EnforceNotMet(std::current_exception(), \ + __FILE__, __LINE__); \ + } \ + } \ + } while (0) +#endif + +/* + * Some enforce helpers here, usage: + * int a = 1; + * int b = 2; + * PADDLE_ENFORCE_EQ(a, b); + * + * will raise an expression described as follows: + * "Expected input a == b, but received a(1) != b(2)." + * with detailed stack information. + * + * extra messages is also supported, for example: + * PADDLE_ENFORCE(a, b, "some simple enforce failed between %d numbers", 2) + */ +#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \ + do { \ + if (UNLIKELY(nullptr == (__VAL))) { \ + PADDLE_THROW(#__VAL " should not be null\n%s", \ + ::paddle::string::Sprintf(__VA_ARGS__)); \ + } \ + } while (0) + +#define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...) \ + do { \ + auto __val1 = (__VAL1); \ + auto __val2 = (__VAL2); \ + using __TYPE1__ = decltype(__val1); \ + using __TYPE2__ = decltype(__val2); \ + using __COMMON_TYPE1__ = \ + ::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \ + using __COMMON_TYPE2__ = \ + ::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \ + bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \ + static_cast<__COMMON_TYPE2__>(__val2)); \ + if (UNLIKELY(!__is_not_error)) { \ + constexpr bool __kCanToString__ = \ + ::paddle::platform::details::CanToString<__TYPE1__>::kValue && \ + ::paddle::platform::details::CanToString<__TYPE2__>::kValue; \ + PADDLE_THROW("Expected %s " #__CMP " %s, but received %s " #__INV_CMP \ + " %s.\n%s", \ + #__VAL1, #__VAL2, \ + ::paddle::platform::details::BinaryCompareMessageConverter< \ + __kCanToString__>::Convert(#__VAL1, __val1), \ + ::paddle::platform::details::BinaryCompareMessageConverter< \ + __kCanToString__>::Convert(#__VAL2, __val2), \ + ::paddle::string::Sprintf(__VA_ARGS__)); \ + } \ + } while (0) + +#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, ==, !=, __VA_ARGS__) +#define PADDLE_ENFORCE_NE(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, !=, ==, __VA_ARGS__) +#define PADDLE_ENFORCE_GT(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >, <=, __VA_ARGS__) +#define PADDLE_ENFORCE_GE(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >=, <, __VA_ARGS__) +#define PADDLE_ENFORCE_LT(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__) +#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__) + +/** OTHER EXCEPTION AND ENFORCE **/ + struct EOFException : public std::exception { std::string err_str_; EOFException(const char* err_msg, const char* file, int line) { @@ -140,34 +334,19 @@ struct EOFException : public std::exception { const char* what() const noexcept override { return err_str_.c_str(); } }; -// Because most enforce conditions would evaluate to true, we can use -// __builtin_expect to instruct the C++ compiler to generate code that -// always forces branch prediction of true. -// This generates faster binary code. __builtin_expect is since C++11. -// For more details, please check https://stackoverflow.com/a/43870188/724872. -#if !defined(_WIN32) -#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) -#else -// there is no equivalent intrinsics in msvc. -#define UNLIKELY(condition) (condition) -#endif - -#if !defined(_WIN32) -#define LIKELY(condition) __builtin_expect(static_cast(condition), 1) -#else -// there is no equivalent intrinsics in msvc. -#define LIKELY(condition) (condition) -#endif +#define PADDLE_THROW_EOF() \ + do { \ + throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \ + __LINE__); \ + } while (0) -inline bool is_error(bool stat) { return !stat; } +#define PADDLE_THROW_BAD_ALLOC(...) \ + do { \ + throw ::paddle::memory::allocation::BadAlloc( \ + ::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \ + } while (0) -inline void throw_on_error(bool stat, const std::string& msg) { -#ifndef REPLACE_ENFORCE_GLOG - throw std::runtime_error(msg); -#else - LOG(FATAL) << msg; -#endif -} +/** CUDA PADDLE ENFORCE FUNCTIONS AND MACROS **/ #ifdef PADDLE_WITH_CUDA @@ -251,6 +430,7 @@ inline void throw_on_error(ncclResult_t stat, const std::string& msg) { #endif } #endif // __APPLE__ and windows + #endif // PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA @@ -276,41 +456,7 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess); #endif } // namespace details -#endif - -#define PADDLE_THROW(...) \ - do { \ - throw ::paddle::platform::EnforceNotMet( \ - ::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \ - } while (0) - -#if defined(__CUDA_ARCH__) -// For cuda, the assertions can affect performance and it is therefore -// recommended to disable them in production code -// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion -#define PADDLE_ENFORCE(_IS_NOT_ERROR, __FORMAT, ...) \ - do { \ - if (!(_IS_NOT_ERROR)) { \ - printf("Exception: %s:%d Assertion `%s` failed. " __FORMAT "\n", \ - __FILE__, __LINE__, #_IS_NOT_ERROR, ##__VA_ARGS__); \ - asm("trap;"); \ - } \ - } while (0) -#else -#define PADDLE_ENFORCE(COND, ...) \ - do { \ - auto __cond__ = (COND); \ - if (UNLIKELY(::paddle::platform::is_error(__cond__))) { \ - try { \ - ::paddle::platform::throw_on_error( \ - __cond__, ::paddle::string::Sprintf(__VA_ARGS__)); \ - } catch (...) { \ - throw ::paddle::platform::EnforceNotMet(std::current_exception(), \ - __FILE__, __LINE__); \ - } \ - } \ - } while (0) -#endif +#endif // PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA #define PADDLE_ENFORCE_CUDA_SUCCESS(COND, ...) \ @@ -332,157 +478,7 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess); } while (0) #undef DEFINE_CUDA_STATUS_TYPE -#endif - -#define PADDLE_THROW_EOF() \ - do { \ - throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \ - __LINE__); \ - } while (0) - -#define PADDLE_THROW_BAD_ALLOC(...) \ - do { \ - throw ::paddle::memory::allocation::BadAlloc( \ - ::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \ - } while (0) - -/* - * Some enforce helpers here, usage: - * int a = 1; - * int b = 2; - * PADDLE_ENFORCE_EQ(a, b); - * - * will raise an expression described as follows: - * "Expected input a == b, but received a(1) != b(2)." - * with detailed stack information. - * - * extra messages is also supported, for example: - * PADDLE_ENFORCE(a, b, "some simple enforce failed between %d numbers", 2) - */ -#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \ - do { \ - if (UNLIKELY(nullptr == (__VAL))) { \ - PADDLE_THROW(#__VAL " should not be null\n%s", \ - ::paddle::string::Sprintf(__VA_ARGS__)); \ - } \ - } while (0) - -namespace details { -template -inline constexpr bool IsArithmetic() { - return std::is_arithmetic::value; -} - -template -struct TypeConverterImpl { - using Type1 = typename std::common_type::type; - using Type2 = Type1; -}; - -template -struct TypeConverterImpl { - using Type1 = T1; - using Type2 = T2; -}; - -template -struct TypeConverter { - private: - static constexpr bool kIsArithmetic = - IsArithmetic() && IsArithmetic(); - - public: - using Type1 = typename TypeConverterImpl::Type1; - using Type2 = typename TypeConverterImpl::Type2; -}; - -template -using CommonType1 = typename std::add_lvalue_reference< - typename std::add_const::Type1>::type>::type; - -template -using CommonType2 = typename std::add_lvalue_reference< - typename std::add_const::Type2>::type>::type; - -// Here, we use SFINAE to check whether T can be converted to std::string -template -struct CanToString { - private: - using YesType = uint8_t; - using NoType = uint16_t; - - template - static YesType Check(decltype(std::cout << std::declval())) { - return 0; - } - - template - static NoType Check(...) { - return 0; - } - - public: - static constexpr bool kValue = - std::is_same(std::cout))>::value; -}; - -template -struct BinaryCompareMessageConverter { - template - static std::string Convert(const char* expression, const T& value) { - return expression + std::string(":") + string::to_string(value); - } -}; - -template <> -struct BinaryCompareMessageConverter { - template - static const char* Convert(const char* expression, const T& value) { - return expression; - } -}; - -} // namespace details - -#define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...) \ - do { \ - auto __val1 = (__VAL1); \ - auto __val2 = (__VAL2); \ - using __TYPE1__ = decltype(__val1); \ - using __TYPE2__ = decltype(__val2); \ - using __COMMON_TYPE1__ = \ - ::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \ - using __COMMON_TYPE2__ = \ - ::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \ - bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \ - static_cast<__COMMON_TYPE2__>(__val2)); \ - if (UNLIKELY(!__is_not_error)) { \ - constexpr bool __kCanToString__ = \ - ::paddle::platform::details::CanToString<__TYPE1__>::kValue && \ - ::paddle::platform::details::CanToString<__TYPE2__>::kValue; \ - PADDLE_THROW("Expected %s " #__CMP " %s, but received %s " #__INV_CMP \ - " %s.\n%s", \ - #__VAL1, #__VAL2, \ - ::paddle::platform::details::BinaryCompareMessageConverter< \ - __kCanToString__>::Convert(#__VAL1, __val1), \ - ::paddle::platform::details::BinaryCompareMessageConverter< \ - __kCanToString__>::Convert(#__VAL2, __val2), \ - ::paddle::string::Sprintf(__VA_ARGS__)); \ - } \ - } while (0) - -#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \ - __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, ==, !=, __VA_ARGS__) -#define PADDLE_ENFORCE_NE(__VAL0, __VAL1, ...) \ - __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, !=, ==, __VA_ARGS__) -#define PADDLE_ENFORCE_GT(__VAL0, __VAL1, ...) \ - __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >, <=, __VA_ARGS__) -#define PADDLE_ENFORCE_GE(__VAL0, __VAL1, ...) \ - __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >=, <, __VA_ARGS__) -#define PADDLE_ENFORCE_LT(__VAL0, __VAL1, ...) \ - __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__) -#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \ - __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__) +#endif // PADDLE_WITH_CUDA } // namespace platform } // namespace paddle -- GitLab