From 2ab2869c2d5d37c201f08e57a787693aa0b3df93 Mon Sep 17 00:00:00 2001 From: dongdaxiang Date: Wed, 17 Apr 2019 09:19:54 +0800 Subject: [PATCH] fix GPU compile error problem --- paddle/fluid/framework/fleet/nccl_wrapper.cc | 12 ++++++++++++ paddle/fluid/framework/fleet/nccl_wrapper.h | 4 ++++ 2 files changed, 16 insertions(+) diff --git a/paddle/fluid/framework/fleet/nccl_wrapper.cc b/paddle/fluid/framework/fleet/nccl_wrapper.cc index 0df6aca8b12..051f4b013c6 100644 --- a/paddle/fluid/framework/fleet/nccl_wrapper.cc +++ b/paddle/fluid/framework/fleet/nccl_wrapper.cc @@ -24,33 +24,43 @@ std::shared_ptr NCCLWrapper::s_instance_ = NULL; bool NCCLWrapper::is_initialized_ = false; void NCCLWrapper::InitNCCL() { +#ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE(platform::dynload::ncclCommInitRank( &(nccl_info_.comm_), nccl_info_.global_ranks_, nccl_info_.nccl_id_, nccl_info_.my_global_rank_)); +#endif return; } void NCCLWrapper::SetNCCLId(const NCCLInfo& nccl_info) { +#ifdef PADDLE_WITH_CUDA nccl_info_.nccl_id_ = nccl_info.nccl_id_; +#endif + return; } NCCLInfo NCCLWrapper::GetNCCLId() { +#ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(&(nccl_info_.nccl_id_))); +#endif return nccl_info_; } void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank, const int ranks) { +#ifdef PADDLE_WITH_CUDA nccl_info_.local_rank_ = local_rank; nccl_info_.my_global_rank_ = global_rank; nccl_info_.global_ranks_ = ranks; PADDLE_ENFORCE(cudaSetDevice(local_rank)); PADDLE_ENFORCE(cudaStreamCreate(&(nccl_info_.stream_))); +#endif return; } void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope, const std::vector& var_names) { +#ifdef PADDLE_WITH_CUDA for (auto& name : var_names) { auto var = scope.FindVar(name); LoDTensor* tensor = var->GetMutable(); @@ -60,6 +70,8 @@ void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope, root_rank, nccl_info_.comm_, nccl_info_.stream_)); cudaStreamSynchronize(nccl_info_.stream_); } +#endif + return; } } // end namespace framework diff --git a/paddle/fluid/framework/fleet/nccl_wrapper.h b/paddle/fluid/framework/fleet/nccl_wrapper.h index eb4e5e19a3a..f29aa225419 100644 --- a/paddle/fluid/framework/fleet/nccl_wrapper.h +++ b/paddle/fluid/framework/fleet/nccl_wrapper.h @@ -24,7 +24,9 @@ limitations under the License. */ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/variable_helper.h" +#ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/dynload/nccl.h" +#endif #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN namespace paddle { @@ -39,9 +41,11 @@ class NCCLInfo { int local_rank_; int global_ranks_; int my_global_rank_; +#ifdef PADDLE_WITH_CUDA ncclUniqueId nccl_id_; ncclComm_t comm_; cudaStream_t stream_; +#endif }; class NCCLWrapper { -- GitLab