Commit 5cc83f79 authored by Yancey1989

update by comment

Parent 82726402
@@ -110,23 +110,30 @@ ParallelExecutor::ParallelExecutor(
   // Bcast Parameters to all GPUs
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
-  std::unique_ptr<ncclUniqueId> nccl_id = nullptr;
+  ncclUniqueId *nccl_id = nullptr;
   bool need_group_call = true;
-  if (nccl_id_var != nullptr) {
-    nccl_id.reset(nccl_id_var->GetMutable<ncclUniqueId>());
-  } else if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
-    nccl_id.reset(new ncclUniqueId());
-    PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id.get()));
-    *member_->global_scope_->Var(NCCL_ID_VARNAME)
-        ->GetMutable<ncclUniqueId>() = *nccl_id.get();
+  if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
+    // parallel graph mode should initialize nccl by ncclCommInitRank since
+    // it calls the nccl operator per device per thread.
+    if (nccl_id_var == nullptr) {
+      nccl_id = new ncclUniqueId();
+      PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id));
+      *member_->global_scope_->Var(NCCL_ID_VARNAME)
+          ->GetMutable<ncclUniqueId>() = *nccl_id;
+    } else {
+      nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
+    }
     need_group_call = false;
+  } else if (nccl_id_var != nullptr) {  // the other executor type.
+    // distributed training with nccl mode initializes the nccl id in
+    // startup_program.
+    nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
   } else {
-    // init nccl_id in NCCLContextMap
+    // initialize NCCL by ncclCommInitAll; no nccl_id is needed.
   }
 
   member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
-      member_->places_, nccl_id.get(), num_trainers, trainer_id,
-      need_group_call));
+      member_->places_, nccl_id, num_trainers, trainer_id, need_group_call));
 #else
   PADDLE_THROW("Not compiled with CUDA");
 #endif
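
For readers less familiar with the two NCCL entry points this branch chooses between, below is a minimal standalone sketch (plain NCCL/CUDA, not Paddle's NCCLContextMap) contrasting ncclCommInitAll, which needs no ncclUniqueId, with per-rank ncclCommInitRank driven by a shared ncclUniqueId, the pattern parallel-graph mode relies on. The helper names InitAll and InitPerRank are illustrative only, and the link drawn to the need_group_call flag is my reading of the diff rather than something it states.

// Standalone illustration only -- not Paddle code. Error checking omitted.
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

// Path relied on by the final `else` branch above: a single process creates
// one communicator per listed device in one call; no ncclUniqueId is needed.
std::vector<ncclComm_t> InitAll(int ndev) {
  std::vector<ncclComm_t> comms(ndev);
  std::vector<int> devs(ndev);
  for (int i = 0; i < ndev; ++i) devs[i] = i;
  ncclCommInitAll(comms.data(), ndev, devs.data());
  return comms;
}

// Pattern parallel-graph mode needs: every device (rank) gets its own
// communicator built from a shared ncclUniqueId via ncclCommInitRank. The
// ncclGroupStart/End pair lets one thread initialize several ranks without
// deadlocking -- presumably the concern behind the need_group_call flag.
std::vector<ncclComm_t> InitPerRank(int ndev, const ncclUniqueId &id) {
  std::vector<ncclComm_t> comms(ndev);
  ncclGroupStart();
  for (int rank = 0; rank < ndev; ++rank) {
    cudaSetDevice(rank);
    ncclCommInitRank(&comms[rank], /*nranks=*/ndev, id, rank);
  }
  ncclGroupEnd();
  return comms;
}

int main() {
  int ndev = 0;
  cudaGetDeviceCount(&ndev);

  // ncclGetUniqueId runs once (e.g. on trainer 0) and the id is shared with
  // every participant; in the diff it is stored under NCCL_ID_VARNAME.
  ncclUniqueId id;
  ncclGetUniqueId(&id);

  for (ncclComm_t c : InitPerRank(ndev, id)) ncclCommDestroy(c);
  return 0;
}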