提交 5cc83f79 编写于 作者: Y Yancey1989

update by comment

上级 82726402
...@@ -110,23 +110,30 @@ ParallelExecutor::ParallelExecutor( ...@@ -110,23 +110,30 @@ ParallelExecutor::ParallelExecutor(
// Bcast Parameters to all GPUs // Bcast Parameters to all GPUs
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME); auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
std::unique_ptr<ncclUniqueId> nccl_id = nullptr; ncclUniqueId *nccl_id = nullptr;
bool need_group_call = true; bool need_group_call = true;
if (nccl_id_var != nullptr) { if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
nccl_id.reset(nccl_id_var->GetMutable<ncclUniqueId>()); // parallel graph mode should initialize nccl by ncclCommInitRank since
} else if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) { // it calls the nccl operator per device per thread.
nccl_id.reset(new ncclUniqueId()); if (nccl_id_var == nullptr) {
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id.get())); nccl_id = new ncclUniqueId();
*member_->global_scope_->Var(NCCL_ID_VARNAME) PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id));
->GetMutable<ncclUniqueId>() = *nccl_id.get(); *member_->global_scope_->Var(NCCL_ID_VARNAME)
->GetMutable<ncclUniqueId>() = *nccl_id;
} else {
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
}
need_group_call = false; need_group_call = false;
} else if (nccl_id_var != nullptr) { // the other executor type.
// the distributed training with nccl mode would initialize the nccl id in
// startup_program.
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
} else { } else {
// init nccl_id in NCCLContextMap // initialize NCCL via ncclCommInitAll; no nccl_id is needed.
} }
member_->nccl_ctxs_.reset(new platform::NCCLContextMap( member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
member_->places_, nccl_id.get(), num_trainers, trainer_id, member_->places_, nccl_id, num_trainers, trainer_id, need_group_call));
need_group_call));
#else #else
PADDLE_THROW("Not compiled with CUDA"); PADDLE_THROW("Not compiled with CUDA");
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册