From a0466053073eae411175e19de610dbe7575ad1d7 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 11 Jul 2017 15:47:42 +0800 Subject: [PATCH] Refine CUDA Related libraries --- paddle/platform/CMakeLists.txt | 10 +- paddle/platform/cuda.h | 6 +- paddle/platform/device_context.cc | 13 +++ paddle/platform/dynload/CMakeLists.txt | 1 + paddle/platform/dynload/cublas.cc | 15 +++ paddle/platform/dynload/cublas.h | 89 +++++++---------- paddle/platform/dynload/cudnn.cc | 28 ++++++ paddle/platform/dynload/cudnn.h | 129 ++++++++++++------------- paddle/platform/dynload/curand.cc | 15 +++ paddle/platform/dynload/curand.h | 45 ++++----- 10 files changed, 201 insertions(+), 150 deletions(-) create mode 100644 paddle/platform/device_context.cc create mode 100644 paddle/platform/dynload/cublas.cc create mode 100644 paddle/platform/dynload/cudnn.cc create mode 100644 paddle/platform/dynload/curand.cc diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index ebacd5d6dc8..7a198aec6cf 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -1,8 +1,14 @@ add_subdirectory(dynload) -nv_test(cuda_test SRCS cuda_test.cu) +nv_test(cuda_test SRCS cuda_test.cu DEPS dyload_cuda) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) +IF(WITH_GPU) + set(GPU_CTX_DEPS dyload_cuda dynamic_loader ) +ELSE() + set(GPU_CTX_DEPS) +ENDIF() -nv_test(device_context_test SRCS device_context_test.cc DEPS dynamic_loader place eigen3 glog gflags) +cc_library(device_context SRCS device_context.cc DEPS place eigen3 ${GPU_CTX_DEPS}) +nv_test(device_context_test SRCS device_context_test.cc DEPS device_context glog gflags) diff --git a/paddle/platform/cuda.h b/paddle/platform/cuda.h index 5ed36c0f025..96889abf9eb 100644 --- a/paddle/platform/cuda.h +++ b/paddle/platform/cuda.h @@ -28,19 +28,19 @@ inline void throw_on_error(cudaError_t e, const char* message) { } } -int GetDeviceCount(void) { +inline int GetDeviceCount(void) { int count; throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed"); return count; } -int GetCurrentDeviceId(void) { +inline int GetCurrentDeviceId(void) { int device_id; throw_on_error(cudaGetDevice(&device_id), "cudaGetDevice failed"); return device_id; } -void SetDeviceId(int device_id) { +inline void SetDeviceId(int device_id) { throw_on_error(cudaSetDevice(device_id), "cudaSetDevice failed"); } diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc new file mode 100644 index 00000000000..a2dea2ed1e1 --- /dev/null +++ b/paddle/platform/device_context.cc @@ -0,0 +1,13 @@ +#include + +namespace paddle { +namespace platform { +namespace dynload { +namespace dummy { +// Make DeviceContext A library. +int DUMMY_VAR_FOR_DEV_CTX = 0; + +} // namespace dummy +} // namespace dynload +} // namespace platform +} // namespace paddle \ No newline at end of file diff --git a/paddle/platform/dynload/CMakeLists.txt b/paddle/platform/dynload/CMakeLists.txt index 9f829b70128..4a8866b3d36 100644 --- a/paddle/platform/dynload/CMakeLists.txt +++ b/paddle/platform/dynload/CMakeLists.txt @@ -1 +1,2 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) +nv_library(dyload_cuda SRCS cublas.cc cudnn.cc curand.cc) diff --git a/paddle/platform/dynload/cublas.cc b/paddle/platform/dynload/cublas.cc new file mode 100644 index 00000000000..f83fcf34d74 --- /dev/null +++ b/paddle/platform/dynload/cublas.cc @@ -0,0 +1,15 @@ +#include + +namespace paddle { +namespace platform { +namespace dynload { +std::once_flag cublas_dso_flag; +void *cublas_dso_handle = nullptr; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name; + +CUBLAS_BLAS_ROUTINE_EACH(DEFINE_WRAP); + +} // namespace dynload +} // namespace platform +} // namespace paddle \ No newline at end of file diff --git a/paddle/platform/dynload/cublas.h b/paddle/platform/dynload/cublas.h index 258cc88031a..1332be31b13 100644 --- a/paddle/platform/dynload/cublas.h +++ b/paddle/platform/dynload/cublas.h @@ -23,8 +23,8 @@ namespace paddle { namespace platform { namespace dynload { -std::once_flag cublas_dso_flag; -void *cublas_dso_handle = nullptr; +extern std::once_flag cublas_dso_flag; +extern void *cublas_dso_handle; /** * The following macro definition can generate structs @@ -34,10 +34,10 @@ void *cublas_dso_handle = nullptr; * note: default dynamic linked libs */ #ifdef PADDLE_USE_DSO -#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ +#define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ struct DynLoad__##__name { \ template \ - cublasStatus_t operator()(Args... args) { \ + inline cublasStatus_t operator()(Args... args) { \ typedef cublasStatus_t (*cublasFunc)(Args...); \ std::call_once(cublas_dso_flag, \ paddle::platform::dynload::GetCublasDsoHandle, \ @@ -45,62 +45,43 @@ void *cublas_dso_handle = nullptr; void *p_##__name = dlsym(cublas_dso_handle, #__name); \ return reinterpret_cast(p_##__name)(args...); \ } \ - } __name; // struct DynLoad__##__name + }; \ + extern DynLoad__##__name __name #else -#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - cublasStatus_t operator()(Args... args) { \ - return __name(args...); \ - } \ - } __name; // struct DynLoad__##__name +#define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ + struct DynLoad__##__name { \ + inline template \ + cublasStatus_t operator()(Args... args) { \ + return __name(args...); \ + } \ + }; \ + extern DynLoad__##__name __name #endif -#define DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) DYNAMIC_LOAD_CUBLAS_WRAP(__name) - -// include all needed cublas functions in HPPL -// clang-format off #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasSgemv) \ - __macro(cublasDgemv) \ - __macro(cublasSgemm) \ - __macro(cublasDgemm) \ - __macro(cublasSgeam) \ - __macro(cublasDgeam) \ - -DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate) -DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy) -DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream) -DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode) -DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode) -DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched) -DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched) -DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched) -DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched) -DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched) -DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched) -DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched) -DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetriBatched) -CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP) + __macro(cublasSgemv); \ + __macro(cublasDgemv); \ + __macro(cublasSgemm); \ + __macro(cublasDgemm); \ + __macro(cublasSgeam); \ + __macro(cublasDgeam); \ + __macro(cublasCreate); \ + __macro(cublasDestroy); \ + __macro(cublasSetStream); \ + __macro(cublasSetPointerMode); \ + __macro(cublasGetPointerMode); \ + __macro(cublasSgemmBatched); \ + __macro(cublasDgemmBatched); \ + __macro(cublasCgemmBatched); \ + __macro(cublasZgemmBatched); \ + __macro(cublasSgetrfBatched); \ + __macro(cublasSgetriBatched); \ + __macro(cublasDgetrfBatched); \ + __macro(cublasDgetriBatched) -#undef DYNAMIC_LOAD_CUBLAS_WRAP -#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP -#undef CUBLAS_BLAS_ROUTINE_EACH +CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP); -// clang-format on -#ifndef PADDLE_TYPE_DOUBLE -#define CUBLAS_GEAM paddle::platform::dynload::cublasSgeam -#define CUBLAS_GEMV paddle::platform::dynload::cublasSgemv -#define CUBLAS_GEMM paddle::platform::dynload::cublasSgemm -#define CUBLAS_GETRF paddle::platform::dynload::cublasSgetrfBatched -#define CUBLAS_GETRI paddle::platform::dynload::cublasSgetriBatched -#else -#define CUBLAS_GEAM paddle::platform::dynload::cublasDgeam -#define CUBLAS_GEMV paddle::platform::dynload::cublasDgemv -#define CUBLAS_GEMM paddle::platform::dynload::cublasDgemm -#define CUBLAS_GETRF paddle::platform::dynload::cublasDgetrfBatched -#define CUBLAS_GETRI paddle::platform::dynload::cublasDgetriBatched -#endif +#undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/platform/dynload/cudnn.cc b/paddle/platform/dynload/cudnn.cc new file mode 100644 index 00000000000..8b5e15b5efc --- /dev/null +++ b/paddle/platform/dynload/cudnn.cc @@ -0,0 +1,28 @@ +#include + +namespace paddle { +namespace platform { +namespace dynload { +std::once_flag cudnn_dso_flag; +void* cudnn_dso_handle = nullptr; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +CUDNN_DNN_ROUTINE_EACH(DEFINE_WRAP); +CUDNN_DNN_ROUTINE_EACH_R2(DEFINE_WRAP); + +#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R3 +CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DEFINE_WRAP); +#endif + +#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R4 +CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DEFINE_WRAP); +#endif + +#ifdef CUDNN_DNN_ROUTINE_EACH_R5 +CUDNN_DNN_ROUTINE_EACH_R5(DEFINE_WRAP); +#endif + +} // namespace dynload +} // namespace platform +} // namespace paddle \ No newline at end of file diff --git a/paddle/platform/dynload/cudnn.h b/paddle/platform/dynload/cudnn.h index 0a9562c573c..ef0dd85b083 100644 --- a/paddle/platform/dynload/cudnn.h +++ b/paddle/platform/dynload/cudnn.h @@ -23,12 +23,12 @@ namespace paddle { namespace platform { namespace dynload { -std::once_flag cudnn_dso_flag; -void* cudnn_dso_handle = nullptr; +extern std::once_flag cudnn_dso_flag; +extern void* cudnn_dso_handle; #ifdef PADDLE_USE_DSO -#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \ +#define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \ struct DynLoad__##__name { \ template \ auto operator()(Args... args) -> decltype(__name(args...)) { \ @@ -39,17 +39,19 @@ void* cudnn_dso_handle = nullptr; void* p_##__name = dlsym(cudnn_dso_handle, #__name); \ return reinterpret_cast(p_##__name)(args...); \ } \ - } __name; /* struct DynLoad__##__name */ + }; \ + extern struct DynLoad__##__name __name #else -#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \ +#define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \ struct DynLoad__##__name { \ template \ auto operator()(Args... args) -> decltype(__name(args...)) { \ return __name(args...); \ } \ - } __name; /* struct DynLoad__##__name */ + }; \ + extern DynLoad__##__name __name #endif @@ -57,80 +59,73 @@ void* cudnn_dso_handle = nullptr; * include all needed cudnn functions in HPPL * different cudnn version has different interfaces **/ -// clang-format off -#define CUDNN_DNN_ROUTINE_EACH(__macro) \ - __macro(cudnnSetTensor4dDescriptor) \ - __macro(cudnnSetTensor4dDescriptorEx) \ - __macro(cudnnGetConvolutionNdForwardOutputDim) \ - __macro(cudnnGetConvolutionForwardAlgorithm) \ - __macro(cudnnCreateTensorDescriptor) \ - __macro(cudnnDestroyTensorDescriptor) \ - __macro(cudnnCreateFilterDescriptor) \ - __macro(cudnnSetFilter4dDescriptor) \ - __macro(cudnnSetPooling2dDescriptor) \ - __macro(cudnnDestroyFilterDescriptor) \ - __macro(cudnnCreateConvolutionDescriptor) \ - __macro(cudnnCreatePoolingDescriptor) \ - __macro(cudnnDestroyPoolingDescriptor) \ - __macro(cudnnSetConvolution2dDescriptor) \ - __macro(cudnnDestroyConvolutionDescriptor) \ - __macro(cudnnCreate) \ - __macro(cudnnDestroy) \ - __macro(cudnnSetStream) \ - __macro(cudnnActivationForward) \ - __macro(cudnnConvolutionForward) \ - __macro(cudnnConvolutionBackwardBias) \ - __macro(cudnnGetConvolutionForwardWorkspaceSize) \ - __macro(cudnnTransformTensor) \ - __macro(cudnnPoolingForward) \ - __macro(cudnnPoolingBackward) \ - __macro(cudnnSoftmaxBackward) \ - __macro(cudnnSoftmaxForward) \ - __macro(cudnnGetVersion) \ - __macro(cudnnGetErrorString) -CUDNN_DNN_ROUTINE_EACH(DYNAMIC_LOAD_CUDNN_WRAP) - -#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \ - __macro(cudnnAddTensor) \ - __macro(cudnnConvolutionBackwardData) \ - __macro(cudnnConvolutionBackwardFilter) -CUDNN_DNN_ROUTINE_EACH_R2(DYNAMIC_LOAD_CUDNN_WRAP) +#define CUDNN_DNN_ROUTINE_EACH(__macro) \ + __macro(cudnnSetTensor4dDescriptor); \ + __macro(cudnnSetTensor4dDescriptorEx); \ + __macro(cudnnGetConvolutionNdForwardOutputDim); \ + __macro(cudnnGetConvolutionForwardAlgorithm); \ + __macro(cudnnCreateTensorDescriptor); \ + __macro(cudnnDestroyTensorDescriptor); \ + __macro(cudnnCreateFilterDescriptor); \ + __macro(cudnnSetFilter4dDescriptor); \ + __macro(cudnnSetPooling2dDescriptor); \ + __macro(cudnnDestroyFilterDescriptor); \ + __macro(cudnnCreateConvolutionDescriptor); \ + __macro(cudnnCreatePoolingDescriptor); \ + __macro(cudnnDestroyPoolingDescriptor); \ + __macro(cudnnSetConvolution2dDescriptor); \ + __macro(cudnnDestroyConvolutionDescriptor); \ + __macro(cudnnCreate); \ + __macro(cudnnDestroy); \ + __macro(cudnnSetStream); \ + __macro(cudnnActivationForward); \ + __macro(cudnnConvolutionForward); \ + __macro(cudnnConvolutionBackwardBias); \ + __macro(cudnnGetConvolutionForwardWorkspaceSize); \ + __macro(cudnnTransformTensor); \ + __macro(cudnnPoolingForward); \ + __macro(cudnnPoolingBackward); \ + __macro(cudnnSoftmaxBackward); \ + __macro(cudnnSoftmaxForward); \ + __macro(cudnnGetVersion); \ + __macro(cudnnGetErrorString); +CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) + +#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \ + __macro(cudnnAddTensor); \ + __macro(cudnnConvolutionBackwardData); \ + __macro(cudnnConvolutionBackwardFilter); +CUDNN_DNN_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) // APIs available after R3: #if CUDNN_VERSION >= 3000 -#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \ - __macro(cudnnGetConvolutionBackwardFilterWorkspaceSize) \ - __macro(cudnnGetConvolutionBackwardDataAlgorithm) \ - __macro(cudnnGetConvolutionBackwardFilterAlgorithm) \ - __macro(cudnnGetConvolutionBackwardDataWorkspaceSize) -CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DYNAMIC_LOAD_CUDNN_WRAP) -#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3 +#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \ + __macro(cudnnGetConvolutionBackwardFilterWorkspaceSize); \ + __macro(cudnnGetConvolutionBackwardDataAlgorithm); \ + __macro(cudnnGetConvolutionBackwardFilterAlgorithm); \ + __macro(cudnnGetConvolutionBackwardDataWorkspaceSize); +CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif - // APIs available after R4: #if CUDNN_VERSION >= 4007 -#define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro) \ - __macro(cudnnBatchNormalizationForwardTraining) \ - __macro(cudnnBatchNormalizationForwardInference) \ - __macro(cudnnBatchNormalizationBackward) -CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DYNAMIC_LOAD_CUDNN_WRAP) -#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R4 +#define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro) \ + __macro(cudnnBatchNormalizationForwardTraining); \ + __macro(cudnnBatchNormalizationForwardInference); \ + __macro(cudnnBatchNormalizationBackward); +CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif // APIs in R5 #if CUDNN_VERSION >= 5000 -#define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \ - __macro(cudnnCreateActivationDescriptor) \ - __macro(cudnnSetActivationDescriptor) \ - __macro(cudnnGetActivationDescriptor) \ - __macro(cudnnDestroyActivationDescriptor) -CUDNN_DNN_ROUTINE_EACH_R5(DYNAMIC_LOAD_CUDNN_WRAP) -#undef CUDNN_DNN_ROUTINE_EACH_R5 +#define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \ + __macro(cudnnCreateActivationDescriptor); \ + __macro(cudnnSetActivationDescriptor); \ + __macro(cudnnGetActivationDescriptor); \ + __macro(cudnnDestroyActivationDescriptor); +CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif -#undef CUDNN_DNN_ROUTINE_EACH -// clang-format on } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/platform/dynload/curand.cc b/paddle/platform/dynload/curand.cc new file mode 100644 index 00000000000..5c1fab992c9 --- /dev/null +++ b/paddle/platform/dynload/curand.cc @@ -0,0 +1,15 @@ +#include + +namespace paddle { +namespace platform { +namespace dynload { + +std::once_flag curand_dso_flag; +void *curand_dso_handle; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +CURAND_RAND_ROUTINE_EACH(DEFINE_WRAP); +} +} +} \ No newline at end of file diff --git a/paddle/platform/dynload/curand.h b/paddle/platform/dynload/curand.h index 9dc0a25c0fb..d8c46bc41e1 100644 --- a/paddle/platform/dynload/curand.h +++ b/paddle/platform/dynload/curand.h @@ -22,10 +22,10 @@ limitations under the License. */ namespace paddle { namespace platform { namespace dynload { -std::once_flag curand_dso_flag; -void *curand_dso_handle = nullptr; +extern std::once_flag curand_dso_flag; +extern void *curand_dso_handle; #ifdef PADDLE_USE_DSO -#define DYNAMIC_LOAD_CURAND_WRAP(__name) \ +#define DECLARE_DYNAMIC_LOAD_CURAND_WRAP(__name) \ struct DynLoad__##__name { \ template \ curandStatus_t operator()(Args... args) { \ @@ -36,32 +36,29 @@ void *curand_dso_handle = nullptr; void *p_##__name = dlsym(curand_dso_handle, #__name); \ return reinterpret_cast(p_##__name)(args...); \ } \ - } __name; /* struct DynLoad__##__name */ + }; \ + extern DynLoad__##__name __name #else -#define DYNAMIC_LOAD_CURAND_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - curandStatus_t operator()(Args... args) { \ - return __name(args...); \ - } \ - } __name; /* struct DynLoad__##__name */ +#define DECLARE_DYNAMIC_LOAD_CURAND_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + curandStatus_t operator()(Args... args) { \ + return __name(args...); \ + } \ + }; \ + extern DynLoad__##__name __name #endif -/* include all needed curand functions in HPPL */ -// clang-format off -#define CURAND_RAND_ROUTINE_EACH(__macro) \ - __macro(curandCreateGenerator) \ - __macro(curandSetStream) \ - __macro(curandSetPseudoRandomGeneratorSeed)\ - __macro(curandGenerateUniform) \ - __macro(curandGenerateUniformDouble) \ - __macro(curandDestroyGenerator) -// clang-format on +#define CURAND_RAND_ROUTINE_EACH(__macro) \ + __macro(curandCreateGenerator); \ + __macro(curandSetStream); \ + __macro(curandSetPseudoRandomGeneratorSeed); \ + __macro(curandGenerateUniform); \ + __macro(curandGenerateUniformDouble); \ + __macro(curandDestroyGenerator); -CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP) +CURAND_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CURAND_WRAP); -#undef CURAND_RAND_ROUTINE_EACH -#undef DYNAMIC_LOAD_CURAND_WRAP } // namespace dynload } // namespace platform } // namespace paddle -- GitLab