From a77fcef3f99724e85e2239ad91683b7afe913cd8 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 3 Jul 2017 12:55:39 +0000 Subject: [PATCH] fix cuda compile error --- paddle/platform/cublas.h | 3 -- paddle/platform/cuda.h | 9 ++++++ paddle/platform/curand.h | 5 ++- paddle/platform/device_context.h | 52 +++++++++++++++++-------------- paddle/platform/dynamic_loader.cc | 4 +-- 5 files changed, 43 insertions(+), 30 deletions(-) diff --git a/paddle/platform/cublas.h b/paddle/platform/cublas.h index 70c97133252..d60eb501e9b 100644 --- a/paddle/platform/cublas.h +++ b/paddle/platform/cublas.h @@ -3,7 +3,6 @@ namespace paddle { namespace dyload { -namespace dynload { std::once_flag cublas_dso_flag; void *cublas_dso_handle = nullptr; @@ -67,8 +66,6 @@ CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP) #undef DYNAMIC_LOAD_CUBLAS_V2_WRAP #undef CUBLAS_BLAS_ROUTINE_EACH -} /* namespace dynload */ - // clang-format on #ifndef PADDLE_TYPE_DOUBLE #define CUBLAS_GEAM dynload::cublasSgeam diff --git a/paddle/platform/cuda.h b/paddle/platform/cuda.h index 8fe891f9ce6..05290b0e1e7 100644 --- a/paddle/platform/cuda.h +++ b/paddle/platform/cuda.h @@ -33,6 +33,15 @@ int GetDeviceCount(void) { throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed"); return count; } +int GetCurrentDeviceId(void) { + int device_id; + throw_on_error(cudaGetDevice(&device_id), "cudaGetDevice failed"); + return device_id; +} + +void SetDeviceId(int device_id) { + throw_on_error(cudaSetDevice(device_id), "cudaSetDevice failed"); +} } // namespace platform } // namespace paddle diff --git a/paddle/platform/curand.h b/paddle/platform/curand.h index 692c024e6ec..edff6526bd8 100644 --- a/paddle/platform/curand.h +++ b/paddle/platform/curand.h @@ -3,6 +3,8 @@ namespace paddle { namespace dyload { +std::once_flag curand_dso_flag; +void *curand_dso_handle = nullptr; #ifdef PADDLE_USE_DSO #define DYNAMIC_LOAD_CURAND_WRAP(__name) \ struct DynLoad__##__name { \ @@ -31,7 +33,8 @@ namespace dyload { __macro(curandSetStream) \ __macro(curandSetPseudoRandomGeneratorSeed)\ __macro(curandGenerateUniform) \ - __macro(curandGenerateUniformDouble) + __macro(curandGenerateUniformDouble) \ + __macro(curandDestroyGenerator) // clang-format on CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP) diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index f95aac4a360..65e76666a79 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -83,11 +83,12 @@ class CudaDeviceContext : public DeviceContext { cublasHandle_t cublas_handle() { if (!blas_handle_) { DeviceGuard guard(gpu_place_); - PADDLE_ENFORCE(cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS, - "cublasCreate failed"); PADDLE_ENFORCE( - cublasSetStream(blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS, - "cublasSetStream failed"); + paddle::dyload::cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS, + "cublasCreate failed"); + PADDLE_ENFORCE(paddle::dyload::cublasSetStream(blas_handle_, stream_) == + CUBLAS_STATUS_SUCCESS, + "cublasSetStream failed"); } return blas_handle_; } @@ -95,11 +96,12 @@ class CudaDeviceContext : public DeviceContext { cudnnHandle_t cudnn_handle() { if (!dnn_handle_) { DeviceGuard guard(gpu_place_); - PADDLE_ENFORCE(cudnnCreate(&dnn_handle_) == CUDNN_STATUS_SUCCESS, - "cudnnCreate failed"); PADDLE_ENFORCE( - cudnnSetStream(dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS, - "cudnnSetStream failed"); + paddle::dyload::cudnnCreate(&dnn_handle_) == CUDNN_STATUS_SUCCESS, + "cudnnCreate failed"); + PADDLE_ENFORCE(paddle::dyload::cudnnSetStream(dnn_handle_, stream_) == + CUDNN_STATUS_SUCCESS, + "cudnnSetStream failed"); } return dnn_handle_; } @@ -107,17 +109,17 @@ class CudaDeviceContext : public DeviceContext { curandGenerator_t curand_generator() { if (!rand_generator_) { DeviceGuard guard(gpu_place_); + PADDLE_ENFORCE(paddle::dyload::curandCreateGenerator( + &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == + CURAND_STATUS_SUCCESS, + "curandCreateGenerator failed"); PADDLE_ENFORCE( - curandCreateGenerator(&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == - CURAND_STATUS_SUCCESS, - "curandCreateGenerator failed"); - PADDLE_ENFORCE( - curandSetPseudoRandomGeneratorSeed(rand_generator_, random_seed_) == - CURAND_STATUS_SUCCESS, + paddle::dyload::curandSetPseudoRandomGeneratorSeed( + rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS, "curandSetPseudoRandomGeneratorSeed failed"); - PADDLE_ENFORCE( - curandSetStream(rand_generator_, stream_) == CURAND_STATUS_SUCCESS, - "curandSetStream failed"); + PADDLE_ENFORCE(paddle::dyload::curandSetStream( + rand_generator_, stream_) == CURAND_STATUS_SUCCESS, + "curandSetStream failed"); } return rand_generator_; } @@ -125,19 +127,21 @@ class CudaDeviceContext : public DeviceContext { ~CudaDeviceContext() { Wait(); if (blas_handle_) { - PADDLE_ENFORCE(cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS, - "cublasDestroy failed"); + PADDLE_ENFORCE( + paddle::dyload::cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS, + "cublasDestroy failed"); } if (dnn_handle_) { - PADDLE_ENFORCE(cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS, - "cudnnDestroy failed"); + PADDLE_ENFORCE( + paddle::dyload::cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS, + "cudnnDestroy failed"); } if (rand_generator_) { - PADDLE_ENFORCE( - curandDestroyGenerator(rand_generator_) == CURAND_STATUS_SUCCESS, - "curandDestroyGenerator failed"); + PADDLE_ENFORCE(paddle::dyload::curandDestroyGenerator(rand_generator_) == + CURAND_STATUS_SUCCESS, + "curandDestroyGenerator failed"); } delete eigen_stream_; diff --git a/paddle/platform/dynamic_loader.cc b/paddle/platform/dynamic_loader.cc index 9036eaf6426..c34abc392c4 100644 --- a/paddle/platform/dynamic_loader.cc +++ b/paddle/platform/dynamic_loader.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "dynamic_loader.h" #include -#include "DynamicLoader.h" -#include "Logging.h" +#include DEFINE_string(cudnn_dir, "", "Specify path for loading libcudnn.so. For instance, " -- GitLab