From 3b5f3548529faae0dfdd6430b1b0d93a9a0b99df Mon Sep 17 00:00:00 2001 From: silingtong123 <35439432+silingtong123@users.noreply.github.com> Date: Tue, 20 Aug 2019 20:33:44 +0800 Subject: [PATCH] Modify PADDLE_ENFORCE to PADDLE_ENFORCE_CUDA_SUCCESS (#19247) * add PADDLE_ENFORCE_CUDA_SUCCESS, test=develop (#19211) * test=develop,Modify PADDLE_ENFORCE to PADDLE_ENFORCE_CUDA_SUCCESS --- paddle/fluid/operators/math/blas_impl.cu.h | 28 +++++++------ paddle/fluid/platform/enforce.h | 47 ++++++++++++++++++++++ paddle/fluid/platform/enforce_test.cc | 43 ++++++++++++++++++++ 3 files changed, 105 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index 58f7be12ce6..4188e26fc98 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -31,23 +31,24 @@ template <> struct CUBlas { template static void GEMM(ARGS... args) { - PADDLE_ENFORCE(platform::dynload::cublasSgemm(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemm(args...)); } template static void AXPY(ARGS... args) { - PADDLE_ENFORCE(platform::dynload::cublasSaxpy(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSaxpy(args...)); } template static void GEMV(ARGS... args) { - PADDLE_ENFORCE(platform::dynload::cublasSgemv(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemv(args...)); } template static void GEMM_STRIDED_BATCH(ARGS... args) { #if CUDA_VERSION >= 8000 - PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasSgemmStridedBatched(args...)); #else PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5"); #endif @@ -69,7 +70,7 @@ struct CUBlas { VLOG(5) << "use_tensor_op_math: " << (dev_ctx->tensor_core_available() ? "True" : "False"); dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE(platform::dynload::cublasSgemmEx( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemmEx( handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc)); }); @@ -83,23 +84,24 @@ template <> struct CUBlas { template static void GEMM(ARGS... args) { - PADDLE_ENFORCE(platform::dynload::cublasDgemm(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemm(args...)); } template static void AXPY(ARGS... args) { - PADDLE_ENFORCE(platform::dynload::cublasDaxpy(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDaxpy(args...)); } template static void GEMV(ARGS... args) { - PADDLE_ENFORCE(platform::dynload::cublasDgemv(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemv(args...)); } template static void GEMM_STRIDED_BATCH(ARGS... args) { #if CUDA_VERSION >= 8000 - PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasDgemmStridedBatched(args...)); #else PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5"); #endif @@ -120,7 +122,7 @@ struct CUBlas { const float16 *alpha, const float16 *A, int lda, const float16 *B, int ldb, const float16 *beta, float16 *C, int ldc) { - PADDLE_ENFORCE( + PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cublasHgemm(handle, transa, transb, m, n, k, reinterpret_cast(alpha), reinterpret_cast(A), lda, @@ -140,7 +142,7 @@ struct CUBlas { long long int strideC, // NOLINT int batchCount) { #if CUDA_VERSION >= 8000 - PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasHgemmStridedBatched( handle, transa, transb, m, n, k, reinterpret_cast(alpha), reinterpret_cast(A), lda, strideA, @@ -174,7 +176,7 @@ struct CUBlas { #endif // CUDA_VERSION >= 9000 dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE(platform::dynload::cublasGemmEx( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmEx( handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo)); }); @@ -356,7 +358,7 @@ void Blas::BatchedGEMM( << (use_tensor_op_math ? "True" : "False"); context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx( handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo)); diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 127be44525b..24ada37807a 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -236,6 +236,31 @@ inline void throw_on_error(ncclResult_t stat, const std::string& msg) { #endif // __APPLE__ and windows #endif // PADDLE_WITH_CUDA +#ifdef PADDLE_WITH_CUDA +namespace details { + +template +struct CudaStatusType {}; + +#define DEFINE_CUDA_STATUS_TYPE(type, success_value) \ + template <> \ + struct CudaStatusType { \ + using Type = type; \ + static constexpr Type kSuccess = success_value; \ + } + +DEFINE_CUDA_STATUS_TYPE(cudaError_t, cudaSuccess); +DEFINE_CUDA_STATUS_TYPE(curandStatus_t, CURAND_STATUS_SUCCESS); +DEFINE_CUDA_STATUS_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS); +DEFINE_CUDA_STATUS_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS); + +#if !defined(__APPLE__) && !defined(_WIN32) +DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess); +#endif + +} // namespace details +#endif + #define PADDLE_THROW(...) \ do { \ throw ::paddle::platform::EnforceNotMet( \ @@ -256,6 +281,28 @@ inline void throw_on_error(ncclResult_t stat, const std::string& msg) { } \ } while (0) +#ifdef PADDLE_WITH_CUDA +#define PADDLE_ENFORCE_CUDA_SUCCESS(COND, ...) \ + do { \ + auto __cond__ = (COND); \ + using __CUDA_STATUS_TYPE__ = decltype(__cond__); \ + constexpr auto __success_type__ = \ + ::paddle::platform::details::CudaStatusType< \ + __CUDA_STATUS_TYPE__>::kSuccess; \ + if (UNLIKELY(__cond__ != __success_type__)) { \ + try { \ + ::paddle::platform::throw_on_error( \ + __cond__, ::paddle::string::Sprintf(__VA_ARGS__)); \ + } catch (...) { \ + throw ::paddle::platform::EnforceNotMet(std::current_exception(), \ + __FILE__, __LINE__); \ + } \ + } \ + } while (0) + +#undef DEFINE_CUDA_STATUS_TYPE +#endif + #define PADDLE_THROW_EOF() \ do { \ throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \ diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index adcc95367f1..ceba13b4d64 100644 --- a/paddle/fluid/platform/enforce_test.cc +++ b/paddle/fluid/platform/enforce_test.cc @@ -253,3 +253,46 @@ TEST(EOF_EXCEPTION, THROW_EOF) { } EXPECT_TRUE(caught_eof); } + +#ifdef PADDLE_WITH_CUDA +template +bool CheckCudaStatusSuccess(T value, const std::string& msg = "success") { + PADDLE_ENFORCE_CUDA_SUCCESS(value, msg); + return true; +} + +template +bool CheckCudaStatusFailure( + T value, const std::string& msg = "self-defined cuda status failed") { + try { + PADDLE_ENFORCE_CUDA_SUCCESS(value, msg); + return false; + } catch (paddle::platform::EnforceNotMet& error) { + std::string ex_msg = error.what(); + return ex_msg.find(msg) != std::string::npos; + } +} + +TEST(enforce, cuda_success) { + EXPECT_TRUE(CheckCudaStatusSuccess(cudaSuccess)); + EXPECT_TRUE(CheckCudaStatusFailure(cudaErrorInvalidValue)); + EXPECT_TRUE(CheckCudaStatusFailure(cudaErrorMemoryAllocation)); + + EXPECT_TRUE(CheckCudaStatusSuccess(CURAND_STATUS_SUCCESS)); + EXPECT_TRUE(CheckCudaStatusFailure(CURAND_STATUS_VERSION_MISMATCH)); + EXPECT_TRUE(CheckCudaStatusFailure(CURAND_STATUS_NOT_INITIALIZED)); + + EXPECT_TRUE(CheckCudaStatusSuccess(CUDNN_STATUS_SUCCESS)); + EXPECT_TRUE(CheckCudaStatusFailure(CUDNN_STATUS_NOT_INITIALIZED)); + EXPECT_TRUE(CheckCudaStatusFailure(CUDNN_STATUS_ALLOC_FAILED)); + + EXPECT_TRUE(CheckCudaStatusSuccess(CUBLAS_STATUS_SUCCESS)); + EXPECT_TRUE(CheckCudaStatusFailure(CUBLAS_STATUS_NOT_INITIALIZED)); + EXPECT_TRUE(CheckCudaStatusFailure(CUBLAS_STATUS_INVALID_VALUE)); +#if !defined(__APPLE__) && !defined(_WIN32) + EXPECT_TRUE(CheckCudaStatusSuccess(ncclSuccess)); + EXPECT_TRUE(CheckCudaStatusFailure(ncclUnhandledCudaError)); + EXPECT_TRUE(CheckCudaStatusFailure(ncclSystemError)); +#endif +} +#endif -- GitLab