From af0fbd9012b655d670924e17bd31a050369dffbe Mon Sep 17 00:00:00 2001 From: silingtong123 <35439432+silingtong123@users.noreply.github.com> Date: Mon, 19 Aug 2019 13:25:27 +0800 Subject: [PATCH] change PADDLE_ENFORCE to PADDLE_ENFORCE_CUDA_SUCCESS (#19205) * print error code if cuda related API fails --- paddle/fluid/operators/math/blas_impl.cu.h | 28 ++++++++++++---------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index 58f7be12ce..4188e26fc9 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -31,23 +31,24 @@ template <> struct CUBlas { template static void GEMM(ARGS... args) { - PADDLE_ENFORCE(platform::dynload::cublasSgemm(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemm(args...)); } template static void AXPY(ARGS... args) { - PADDLE_ENFORCE(platform::dynload::cublasSaxpy(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSaxpy(args...)); } template static void GEMV(ARGS... args) { - PADDLE_ENFORCE(platform::dynload::cublasSgemv(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemv(args...)); } template static void GEMM_STRIDED_BATCH(ARGS... args) { #if CUDA_VERSION >= 8000 - PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasSgemmStridedBatched(args...)); #else PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5"); #endif @@ -69,7 +70,7 @@ struct CUBlas { VLOG(5) << "use_tensor_op_math: " << (dev_ctx->tensor_core_available() ? "True" : "False"); dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE(platform::dynload::cublasSgemmEx( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemmEx( handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc)); }); @@ -83,23 +84,24 @@ template <> struct CUBlas { template static void GEMM(ARGS... args) { - PADDLE_ENFORCE(platform::dynload::cublasDgemm(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemm(args...)); } template static void AXPY(ARGS... args) { - PADDLE_ENFORCE(platform::dynload::cublasDaxpy(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDaxpy(args...)); } template static void GEMV(ARGS... args) { - PADDLE_ENFORCE(platform::dynload::cublasDgemv(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemv(args...)); } template static void GEMM_STRIDED_BATCH(ARGS... args) { #if CUDA_VERSION >= 8000 - PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(args...)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasDgemmStridedBatched(args...)); #else PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5"); #endif @@ -120,7 +122,7 @@ struct CUBlas { const float16 *alpha, const float16 *A, int lda, const float16 *B, int ldb, const float16 *beta, float16 *C, int ldc) { - PADDLE_ENFORCE( + PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cublasHgemm(handle, transa, transb, m, n, k, reinterpret_cast(alpha), reinterpret_cast(A), lda, @@ -140,7 +142,7 @@ struct CUBlas { long long int strideC, // NOLINT int batchCount) { #if CUDA_VERSION >= 8000 - PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasHgemmStridedBatched( handle, transa, transb, m, n, k, reinterpret_cast(alpha), reinterpret_cast(A), lda, strideA, @@ -174,7 +176,7 @@ struct CUBlas { #endif // CUDA_VERSION >= 9000 dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE(platform::dynload::cublasGemmEx( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmEx( handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo)); }); @@ -356,7 +358,7 @@ void Blas::BatchedGEMM( << (use_tensor_op_math ? "True" : "False"); context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { - PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx( + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx( handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo)); -- GitLab