提交 f7847ca6 编写于 作者: C chengduozh

fix cublas warp error

test=develop
上级 45312813
...@@ -32,6 +32,9 @@ CUBLAS_BLAS_ROUTINE_EACH_R2(DEFINE_WRAP); ...@@ -32,6 +32,9 @@ CUBLAS_BLAS_ROUTINE_EACH_R2(DEFINE_WRAP);
CUBLAS_BLAS_ROUTINE_EACH_R3(DEFINE_WRAP); CUBLAS_BLAS_ROUTINE_EACH_R3(DEFINE_WRAP);
#endif #endif
#ifdef CUBLAS_BLAS_ROUTINE_EACH_R4
CUBLAS_BLAS_ROUTINE_EACH_R4(DEFINE_WRAP);
#endif
} // namespace dynload } // namespace dynload
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -90,23 +90,33 @@ CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) ...@@ -90,23 +90,33 @@ CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
// APIs available after CUDA 8.0 // APIs available after CUDA 8.0
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmEx); #define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) \
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmStridedBatched); __macro(cublasGemmEx); \
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmStridedBatched); __macro(cublasSgemmStridedBatched); \
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmStridedBatched); __macro(cublasDgemmStridedBatched); \
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmStridedBatched); __macro(cublasCgemmStridedBatched); \
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasHgemmStridedBatched); __macro(cublasZgemmStridedBatched); \
__macro(cublasHgemmStridedBatched);
CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
#endif #endif
// APIs available after CUDA 9.0 // APIs available after CUDA 9.0
#if CUDA_VERSION >= 9000 #if CUDA_VERSION >= 9000
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSetMathMode); #define CUBLAS_BLAS_ROUTINE_EACH_R3(__macro) \
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGetMathMode); __macro(cublasSetMathMode); \
__macro(cublasGetMathMode);
CUBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
#endif #endif
// APIs available after CUDA 9.1
#if CUDA_VERSION >= 9010 #if CUDA_VERSION >= 9010
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmBatchedEx); #define CUBLAS_BLAS_ROUTINE_EACH_R4(__macro) \
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmStridedBatchedEx); __macro(cublasGemmBatchedEx); \
__macro(cublasGemmStridedBatchedEx);
CUBLAS_BLAS_ROUTINE_EACH_R4(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
#endif #endif
#undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP #undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册