提交 af0fbd90 · 作者: silingtong123 · 提交者: liuwei1031

change PADDLE_ENFORCE to PADDLE_ENFORCE_CUDA_SUCCESS (#19205)

* Print the error code if a CUDA-related API call fails
上级 2f0dc846
...@@ -31,23 +31,24 @@ template <> ...@@ -31,23 +31,24 @@ template <>
struct CUBlas<float> { struct CUBlas<float> {
template <typename... ARGS> template <typename... ARGS>
static void GEMM(ARGS... args) { static void GEMM(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasSgemm(args...)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemm(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void AXPY(ARGS... args) { static void AXPY(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasSaxpy(args...)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSaxpy(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GEMV(ARGS... args) { static void GEMV(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasSgemv(args...)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemv(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GEMM_STRIDED_BATCH(ARGS... args) { static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(args...)); PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasSgemmStridedBatched(args...));
#else #else
PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5"); PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5");
#endif #endif
...@@ -69,7 +70,7 @@ struct CUBlas<float> { ...@@ -69,7 +70,7 @@ struct CUBlas<float> {
VLOG(5) << "use_tensor_op_math: " VLOG(5) << "use_tensor_op_math: "
<< (dev_ctx->tensor_core_available() ? "True" : "False"); << (dev_ctx->tensor_core_available() ? "True" : "False");
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasSgemmEx( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemmEx(
handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
beta, C, Ctype, ldc)); beta, C, Ctype, ldc));
}); });
...@@ -83,23 +84,24 @@ template <> ...@@ -83,23 +84,24 @@ template <>
struct CUBlas<double> { struct CUBlas<double> {
template <typename... ARGS> template <typename... ARGS>
static void GEMM(ARGS... args) { static void GEMM(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasDgemm(args...)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemm(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void AXPY(ARGS... args) { static void AXPY(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasDaxpy(args...)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDaxpy(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GEMV(ARGS... args) { static void GEMV(ARGS... args) {
PADDLE_ENFORCE(platform::dynload::cublasDgemv(args...)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemv(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GEMM_STRIDED_BATCH(ARGS... args) { static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(args...)); PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasDgemmStridedBatched(args...));
#else #else
PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5"); PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5");
#endif #endif
...@@ -120,7 +122,7 @@ struct CUBlas<platform::float16> { ...@@ -120,7 +122,7 @@ struct CUBlas<platform::float16> {
const float16 *alpha, const float16 *A, int lda, const float16 *alpha, const float16 *A, int lda,
const float16 *B, int ldb, const float16 *beta, float16 *C, const float16 *B, int ldb, const float16 *beta, float16 *C,
int ldc) { int ldc) {
PADDLE_ENFORCE( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasHgemm(handle, transa, transb, m, n, k, platform::dynload::cublasHgemm(handle, transa, transb, m, n, k,
reinterpret_cast<const __half *>(alpha), reinterpret_cast<const __half *>(alpha),
reinterpret_cast<const __half *>(A), lda, reinterpret_cast<const __half *>(A), lda,
...@@ -140,7 +142,7 @@ struct CUBlas<platform::float16> { ...@@ -140,7 +142,7 @@ struct CUBlas<platform::float16> {
long long int strideC, // NOLINT long long int strideC, // NOLINT
int batchCount) { int batchCount) {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasHgemmStridedBatched(
handle, transa, transb, m, n, k, handle, transa, transb, m, n, k,
reinterpret_cast<const __half *>(alpha), reinterpret_cast<const __half *>(alpha),
reinterpret_cast<const __half *>(A), lda, strideA, reinterpret_cast<const __half *>(A), lda, strideA,
...@@ -174,7 +176,7 @@ struct CUBlas<platform::float16> { ...@@ -174,7 +176,7 @@ struct CUBlas<platform::float16> {
#endif // CUDA_VERSION >= 9000 #endif // CUDA_VERSION >= 9000
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasGemmEx( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmEx(
handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
beta, C, Ctype, ldc, computeType, algo)); beta, C, Ctype, ldc, computeType, algo));
}); });
...@@ -356,7 +358,7 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM( ...@@ -356,7 +358,7 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
<< (use_tensor_op_math ? "True" : "False"); << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb, handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc, strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
strideC, batchCount, CUDA_R_32F, algo)); strideC, batchCount, CUDA_R_32F, algo));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册