提交 e3f5fdcc 编写于 作者: Y Yu Yang

Make PADDLE_ENFORCE and PADDLE_THROW catchable

* Use EnforceNotMet to unify all exception types.
上级 0c2790f7
...@@ -36,6 +36,21 @@ limitations under the License. */ ...@@ -36,6 +36,21 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
struct EnforceNotMet : public std::exception {
std::exception_ptr exp_;
std::string err_str_;
EnforceNotMet(std::exception_ptr e, const char* f, int l) : exp_(e) {
try {
std::rethrow_exception(exp_);
} catch (const std::exception& exp) {
err_str_ = string::Sprintf("%s at [%s:%d]", exp.what(), f, l);
}
}
const char* what() const noexcept { return err_str_.c_str(); }
};
// Because most enforce conditions would evaluate to true, we can use // Because most enforce conditions would evaluate to true, we can use
// __builtin_expect to instruct the C++ compiler to generate code that // __builtin_expect to instruct the C++ compiler to generate code that
// always forces branch prediction of true. // always forces branch prediction of true.
...@@ -52,9 +67,7 @@ template <typename... Args> ...@@ -52,9 +67,7 @@ template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
int stat, const Args&... args) { int stat, const Args&... args) {
if (UNLIKELY(!(stat))) { if (UNLIKELY(!(stat))) {
throw std::runtime_error( throw std::runtime_error(string::Sprintf(args...));
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
} }
} }
...@@ -64,12 +77,8 @@ template <typename... Args> ...@@ -64,12 +77,8 @@ template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
cudaError_t e, const Args&... args) { cudaError_t e, const Args&... args) {
if (UNLIKELY(e)) { if (UNLIKELY(e)) {
// clang-format off throw thrust::system_error(e, thrust::cuda_category(),
throw thrust::system_error( string::Sprintf(args...));
e, thrust::cuda_category(),
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
// clang-format on
} }
} }
...@@ -77,12 +86,8 @@ template <typename... Args> ...@@ -77,12 +86,8 @@ template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
curandStatus_t stat, const Args&... args) { curandStatus_t stat, const Args&... args) {
if (stat != CURAND_STATUS_SUCCESS) { if (stat != CURAND_STATUS_SUCCESS) {
// clang-format off throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
throw thrust::system_error( string::Sprintf(args...));
cudaErrorLaunchFailure, thrust::cuda_category(),
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
// clang-format on
} }
} }
...@@ -92,12 +97,8 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( ...@@ -92,12 +97,8 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
if (stat == CUDNN_STATUS_SUCCESS) { if (stat == CUDNN_STATUS_SUCCESS) {
return; return;
} else { } else {
// clang-format off throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) +
throw std::runtime_error( string::Sprintf(args...));
platform::dynload::cudnnGetErrorString(stat) +
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
// clang-format on
} }
} }
...@@ -126,22 +127,27 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( ...@@ -126,22 +127,27 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
} else if (stat == CUBLAS_STATUS_LICENSE_ERROR) { } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) {
err = "CUBLAS: license error, "; err = "CUBLAS: license error, ";
} }
throw std::runtime_error(err + string::Sprintf(args...) + throw std::runtime_error(err + string::Sprintf(args...));
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
} }
#endif // PADDLE_ONLY_CPU #endif // PADDLE_ONLY_CPU
#define PADDLE_THROW(...) \ #define PADDLE_THROW(...) \
do { \ do { \
throw std::runtime_error( \ throw ::paddle::platform::EnforceNotMet( \
string::Sprintf(__VA_ARGS__) + \ std::make_exception_ptr( \
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \ std::runtime_error(string::Sprintf(__VA_ARGS__))), \
__FILE__, __LINE__); \
} while (0) } while (0)
#define PADDLE_ENFORCE(...) \ #define PADDLE_ENFORCE(...) \
do { \ do { \
try { \
::paddle::platform::throw_on_error(__VA_ARGS__); \ ::paddle::platform::throw_on_error(__VA_ARGS__); \
} catch (...) { \
throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
__FILE__, __LINE__); \
} \
} while (0) } while (0)
} // namespace platform } // namespace platform
......
...@@ -23,7 +23,7 @@ TEST(ENFORCE, FAILED) { ...@@ -23,7 +23,7 @@ TEST(ENFORCE, FAILED) {
bool in_catch = false; bool in_catch = false;
try { try {
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123); PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
} catch (const std::runtime_error& error) { } catch (paddle::platform::EnforceNotMet error) {
// your error handling code here // your error handling code here
in_catch = true; in_catch = true;
std::string msg = "Enforce is not ok 123 at all"; std::string msg = "Enforce is not ok 123 at all";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册