From 69d76812ae3c9e43f46f7a24175c2795ae9034d4 Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 11 Jul 2017 19:15:48 +0800 Subject: [PATCH] fix cublas dynload bug --- paddle/platform/dynload/cublas.cc | 4 ++-- paddle/platform/dynload/cublas.h | 31 +++++++++++++++++-------------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/paddle/platform/dynload/cublas.cc b/paddle/platform/dynload/cublas.cc index f83fcf34d74..4e3dfdaefb2 100644 --- a/paddle/platform/dynload/cublas.cc +++ b/paddle/platform/dynload/cublas.cc @@ -6,10 +6,10 @@ namespace dynload { std::once_flag cublas_dso_flag; void *cublas_dso_handle = nullptr; -#define DEFINE_WRAP(__name) DynLoad__##__name __name; +#define DEFINE_WRAP(__name) DynLoad__##__name __name CUBLAS_BLAS_ROUTINE_EACH(DEFINE_WRAP); } // namespace dynload } // namespace platform -} // namespace paddle \ No newline at end of file +} // namespace paddle diff --git a/paddle/platform/dynload/cublas.h b/paddle/platform/dynload/cublas.h index 1332be31b13..47c7a8ec21f 100644 --- a/paddle/platform/dynload/cublas.h +++ b/paddle/platform/dynload/cublas.h @@ -58,26 +58,29 @@ extern void *cublas_dso_handle; extern DynLoad__##__name __name #endif +#define DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) \ + DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) + #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasSgemv); \ __macro(cublasDgemv); \ __macro(cublasSgemm); \ __macro(cublasDgemm); \ __macro(cublasSgeam); \ - __macro(cublasDgeam); \ - __macro(cublasCreate); \ - __macro(cublasDestroy); \ - __macro(cublasSetStream); \ - __macro(cublasSetPointerMode); \ - __macro(cublasGetPointerMode); \ - __macro(cublasSgemmBatched); \ - __macro(cublasDgemmBatched); \ - __macro(cublasCgemmBatched); \ - __macro(cublasZgemmBatched); \ - __macro(cublasSgetrfBatched); \ - __macro(cublasSgetriBatched); \ - __macro(cublasDgetrfBatched); \ - __macro(cublasDgetriBatched) + __macro(cublasDgeam); + +DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate); +DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy); +DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream); +DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode); +DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched); CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP); -- GitLab