提交 6f010712 编写于 作者: Y yi.wu

fix broadcast bug

上级 88cb47bd
......@@ -153,7 +153,6 @@ void ParallelExecutor::BCastParamsToGPUs(
if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
continue;
}
VLOG(3) << "run broadcast " << var << " " << var_dev_id;
auto &main_tensor = main_var->Get<LoDTensor>();
auto &dims = main_tensor.dims();
......@@ -184,8 +183,16 @@ void ParallelExecutor::BCastParamsToGPUs(
platform::NCCLGroupGuard guard;
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
if (initializing) {
platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());
} else {
if (static_cast<size_t>(var_dev_id)) {
platform::dynload::ncclBcast(buffers[i], numel, data_type,
var_dev_id, nccl_ctx.comm_,
nccl_ctx.stream());
}
}
}
member_->nccl_ctxs_->WaitAll();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册