提交 2ab2869c 编写于 作者: D dongdaxiang

fix GPU compile error problem

上级 466d177d
......@@ -24,33 +24,43 @@ std::shared_ptr<NCCLWrapper> NCCLWrapper::s_instance_ = NULL;
bool NCCLWrapper::is_initialized_ = false;
void NCCLWrapper::InitNCCL() {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
&(nccl_info_.comm_), nccl_info_.global_ranks_, nccl_info_.nccl_id_,
nccl_info_.my_global_rank_));
#endif
return;
}
void NCCLWrapper::SetNCCLId(const NCCLInfo& nccl_info) {
#ifdef PADDLE_WITH_CUDA
nccl_info_.nccl_id_ = nccl_info.nccl_id_;
#endif
return;
}
NCCLInfo NCCLWrapper::GetNCCLId() {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(&(nccl_info_.nccl_id_)));
#endif
return nccl_info_;
}
void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank,
const int ranks) {
#ifdef PADDLE_WITH_CUDA
nccl_info_.local_rank_ = local_rank;
nccl_info_.my_global_rank_ = global_rank;
nccl_info_.global_ranks_ = ranks;
PADDLE_ENFORCE(cudaSetDevice(local_rank));
PADDLE_ENFORCE(cudaStreamCreate(&(nccl_info_.stream_)));
#endif
return;
}
void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope,
const std::vector<std::string>& var_names) {
#ifdef PADDLE_WITH_CUDA
for (auto& name : var_names) {
auto var = scope.FindVar(name);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
......@@ -60,6 +70,8 @@ void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope,
root_rank, nccl_info_.comm_, nccl_info_.stream_));
cudaStreamSynchronize(nccl_info_.stream_);
}
#endif
return;
}
} // end namespace framework
......
......@@ -24,7 +24,9 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
......@@ -39,9 +41,11 @@ class NCCLInfo {
int local_rank_;
int global_ranks_;
int my_global_rank_;
#ifdef PADDLE_WITH_CUDA
ncclUniqueId nccl_id_;
ncclComm_t comm_;
cudaStream_t stream_;
#endif
};
class NCCLWrapper {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册