From e3f5fdcc7a242ec8d65e20554bbc2ceb79c0c900 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 25 Jul 2017 17:04:45 +0800 Subject: [PATCH] Make PADDLE_ENFORCE and PADDLE_THROW catchable * Use EnforceNotMet to unify all exception types. --- paddle/platform/enforce.h | 68 ++++++++++++++++++--------------- paddle/platform/enforce_test.cc | 2 +- 2 files changed, 38 insertions(+), 32 deletions(-) diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index b06ab8a2f1..a3a10fc07f 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -36,6 +36,21 @@ limitations under the License. */ namespace paddle { namespace platform { +struct EnforceNotMet : public std::exception { + std::exception_ptr exp_; + std::string err_str_; + + EnforceNotMet(std::exception_ptr e, const char* f, int l) : exp_(e) { + try { + std::rethrow_exception(exp_); + } catch (const std::exception& exp) { + err_str_ = string::Sprintf("%s at [%s:%d]", exp.what(), f, l); + } + } + + const char* what() const noexcept { return err_str_.c_str(); } +}; + // Because most enforce conditions would evaluate to true, we can use // __builtin_expect to instruct the C++ compiler to generate code that // always forces branch prediction of true. @@ -52,9 +67,7 @@ template inline typename std::enable_if::type throw_on_error( int stat, const Args&... args) { if (UNLIKELY(!(stat))) { - throw std::runtime_error( - string::Sprintf(args...) + - string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); + throw std::runtime_error(string::Sprintf(args...)); } } @@ -64,12 +77,8 @@ template inline typename std::enable_if::type throw_on_error( cudaError_t e, const Args&... args) { if (UNLIKELY(e)) { - // clang-format off - throw thrust::system_error( - e, thrust::cuda_category(), - string::Sprintf(args...) + - string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); - // clang-format on + throw thrust::system_error(e, thrust::cuda_category(), + string::Sprintf(args...)); } } @@ -77,12 +86,8 @@ template inline typename std::enable_if::type throw_on_error( curandStatus_t stat, const Args&... args) { if (stat != CURAND_STATUS_SUCCESS) { - // clang-format off - throw thrust::system_error( - cudaErrorLaunchFailure, thrust::cuda_category(), - string::Sprintf(args...) + - string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); - // clang-format on + throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(), + string::Sprintf(args...)); } } @@ -92,12 +97,8 @@ inline typename std::enable_if::type throw_on_error( if (stat == CUDNN_STATUS_SUCCESS) { return; } else { - // clang-format off - throw std::runtime_error( - platform::dynload::cudnnGetErrorString(stat) + - string::Sprintf(args...) + - string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); - // clang-format on + throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) + + string::Sprintf(args...)); } } @@ -126,22 +127,27 @@ inline typename std::enable_if::type throw_on_error( } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) { err = "CUBLAS: license error, "; } - throw std::runtime_error(err + string::Sprintf(args...) + - string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); + throw std::runtime_error(err + string::Sprintf(args...)); } #endif // PADDLE_ONLY_CPU -#define PADDLE_THROW(...) \ - do { \ - throw std::runtime_error( \ - string::Sprintf(__VA_ARGS__) + \ - string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \ +#define PADDLE_THROW(...) \ + do { \ + throw ::paddle::platform::EnforceNotMet( \ + std::make_exception_ptr( \ + std::runtime_error(string::Sprintf(__VA_ARGS__))), \ + __FILE__, __LINE__); \ } while (0) -#define PADDLE_ENFORCE(...) \ - do { \ - ::paddle::platform::throw_on_error(__VA_ARGS__); \ +#define PADDLE_ENFORCE(...) \ + do { \ + try { \ + ::paddle::platform::throw_on_error(__VA_ARGS__); \ + } catch (...) { \ + throw ::paddle::platform::EnforceNotMet(std::current_exception(), \ + __FILE__, __LINE__); \ + } \ } while (0) } // namespace platform diff --git a/paddle/platform/enforce_test.cc b/paddle/platform/enforce_test.cc index d7152f8150..2ac31812a8 100644 --- a/paddle/platform/enforce_test.cc +++ b/paddle/platform/enforce_test.cc @@ -23,7 +23,7 @@ TEST(ENFORCE, FAILED) { bool in_catch = false; try { PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123); - } catch (const std::runtime_error& error) { + } catch (paddle::platform::EnforceNotMet error) { // your error handling code here in_catch = true; std::string msg = "Enforce is not ok 123 at all"; -- GitLab