diff --git a/paddle/fluid/platform/dynload/cublas.cc b/paddle/fluid/platform/dynload/cublas.cc index 361d3439b844e9f68d3fba0a0e41ec457118a4a9..41648c32fe6f98bb0b78ea7891065e5586f70463 100644 --- a/paddle/fluid/platform/dynload/cublas.cc +++ b/paddle/fluid/platform/dynload/cublas.cc @@ -32,6 +32,9 @@ CUBLAS_BLAS_ROUTINE_EACH_R2(DEFINE_WRAP); CUBLAS_BLAS_ROUTINE_EACH_R3(DEFINE_WRAP); #endif +#ifdef CUBLAS_BLAS_ROUTINE_EACH_R4 +CUBLAS_BLAS_ROUTINE_EACH_R4(DEFINE_WRAP); +#endif } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h index ff80bd525c167eda00f67d392c7b3b71436ee820..ced789b90d067218c3b01d124cfd2c93dc94e528 100644 --- a/paddle/fluid/platform/dynload/cublas.h +++ b/paddle/fluid/platform/dynload/cublas.h @@ -90,23 +90,33 @@ CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) // APIs available after CUDA 8.0 #if CUDA_VERSION >= 8000 -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmEx); -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmStridedBatched); -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmStridedBatched); -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmStridedBatched); -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmStridedBatched); -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasHgemmStridedBatched); +#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) \ + __macro(cublasGemmEx); \ + __macro(cublasSgemmStridedBatched); \ + __macro(cublasDgemmStridedBatched); \ + __macro(cublasCgemmStridedBatched); \ + __macro(cublasZgemmStridedBatched); \ + __macro(cublasHgemmStridedBatched); + +CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) #endif // APIs available after CUDA 9.0 #if CUDA_VERSION >= 9000 -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSetMathMode); -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGetMathMode); +#define CUBLAS_BLAS_ROUTINE_EACH_R3(__macro) \ + __macro(cublasSetMathMode); \ + __macro(cublasGetMathMode); + +CUBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) #endif +// APIs available after CUDA 9.1 #if CUDA_VERSION >= 9010 -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmBatchedEx); -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmStridedBatchedEx); +#define CUBLAS_BLAS_ROUTINE_EACH_R4(__macro) \ + __macro(cublasGemmBatchedEx); \ + __macro(cublasGemmStridedBatchedEx); + +CUBLAS_BLAS_ROUTINE_EACH_R4(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) #endif #undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP