diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 127be44525beca0e2273e591cf2ea5fb332782b4..24ada37807a079ced2de71febb68411e9533a8a2 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -236,6 +236,31 @@ inline void throw_on_error(ncclResult_t stat, const std::string& msg) { #endif // __APPLE__ and windows #endif // PADDLE_WITH_CUDA +#ifdef PADDLE_WITH_CUDA +namespace details { + +template +struct CudaStatusType {}; + +#define DEFINE_CUDA_STATUS_TYPE(type, success_value) \ + template <> \ + struct CudaStatusType { \ + using Type = type; \ + static constexpr Type kSuccess = success_value; \ + } + +DEFINE_CUDA_STATUS_TYPE(cudaError_t, cudaSuccess); +DEFINE_CUDA_STATUS_TYPE(curandStatus_t, CURAND_STATUS_SUCCESS); +DEFINE_CUDA_STATUS_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS); +DEFINE_CUDA_STATUS_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS); + +#if !defined(__APPLE__) && !defined(_WIN32) +DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess); +#endif + +} // namespace details +#endif + #define PADDLE_THROW(...) \ do { \ throw ::paddle::platform::EnforceNotMet( \ @@ -256,6 +281,28 @@ inline void throw_on_error(ncclResult_t stat, const std::string& msg) { } \ } while (0) +#ifdef PADDLE_WITH_CUDA +#define PADDLE_ENFORCE_CUDA_SUCCESS(COND, ...) \ + do { \ + auto __cond__ = (COND); \ + using __CUDA_STATUS_TYPE__ = decltype(__cond__); \ + constexpr auto __success_type__ = \ + ::paddle::platform::details::CudaStatusType< \ + __CUDA_STATUS_TYPE__>::kSuccess; \ + if (UNLIKELY(__cond__ != __success_type__)) { \ + try { \ + ::paddle::platform::throw_on_error( \ + __cond__, ::paddle::string::Sprintf(__VA_ARGS__)); \ + } catch (...) { \ + throw ::paddle::platform::EnforceNotMet(std::current_exception(), \ + __FILE__, __LINE__); \ + } \ + } \ + } while (0) + +#undef DEFINE_CUDA_STATUS_TYPE +#endif + #define PADDLE_THROW_EOF() \ do { \ throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \ diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index adcc95367f11dfa2722226e5a0386bedfa6e746e..ceba13b4d64ebc7d8c76e5eb5bfa3d92aba26d36 100644 --- a/paddle/fluid/platform/enforce_test.cc +++ b/paddle/fluid/platform/enforce_test.cc @@ -253,3 +253,46 @@ TEST(EOF_EXCEPTION, THROW_EOF) { } EXPECT_TRUE(caught_eof); } + +#ifdef PADDLE_WITH_CUDA +template +bool CheckCudaStatusSuccess(T value, const std::string& msg = "success") { + PADDLE_ENFORCE_CUDA_SUCCESS(value, msg); + return true; +} + +template +bool CheckCudaStatusFailure( + T value, const std::string& msg = "self-defined cuda status failed") { + try { + PADDLE_ENFORCE_CUDA_SUCCESS(value, msg); + return false; + } catch (paddle::platform::EnforceNotMet& error) { + std::string ex_msg = error.what(); + return ex_msg.find(msg) != std::string::npos; + } +} + +TEST(enforce, cuda_success) { + EXPECT_TRUE(CheckCudaStatusSuccess(cudaSuccess)); + EXPECT_TRUE(CheckCudaStatusFailure(cudaErrorInvalidValue)); + EXPECT_TRUE(CheckCudaStatusFailure(cudaErrorMemoryAllocation)); + + EXPECT_TRUE(CheckCudaStatusSuccess(CURAND_STATUS_SUCCESS)); + EXPECT_TRUE(CheckCudaStatusFailure(CURAND_STATUS_VERSION_MISMATCH)); + EXPECT_TRUE(CheckCudaStatusFailure(CURAND_STATUS_NOT_INITIALIZED)); + + EXPECT_TRUE(CheckCudaStatusSuccess(CUDNN_STATUS_SUCCESS)); + EXPECT_TRUE(CheckCudaStatusFailure(CUDNN_STATUS_NOT_INITIALIZED)); + EXPECT_TRUE(CheckCudaStatusFailure(CUDNN_STATUS_ALLOC_FAILED)); + + EXPECT_TRUE(CheckCudaStatusSuccess(CUBLAS_STATUS_SUCCESS)); + EXPECT_TRUE(CheckCudaStatusFailure(CUBLAS_STATUS_NOT_INITIALIZED)); + EXPECT_TRUE(CheckCudaStatusFailure(CUBLAS_STATUS_INVALID_VALUE)); +#if !defined(__APPLE__) && !defined(_WIN32) + EXPECT_TRUE(CheckCudaStatusSuccess(ncclSuccess)); + EXPECT_TRUE(CheckCudaStatusFailure(ncclUnhandledCudaError)); + EXPECT_TRUE(CheckCudaStatusFailure(ncclSystemError)); +#endif +} +#endif