Commit 6f010712, authored by yi.wu

fix broadcast bug

Parent: 88cb47bd
@@ -153,7 +153,6 @@ void ParallelExecutor::BCastParamsToGPUs(
     if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
       continue;
     }
-    VLOG(3) << "run broadcast " << var << " " << var_dev_id;
     auto &main_tensor = main_var->Get<LoDTensor>();
     auto &dims = main_tensor.dims();
@@ -184,8 +183,16 @@ void ParallelExecutor::BCastParamsToGPUs(
       platform::NCCLGroupGuard guard;
       for (size_t i = 0; i < member_->places_.size(); ++i) {
         auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
-        platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
-                                     nccl_ctx.comm_, nccl_ctx.stream());
+        if (initializing) {
+          platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
+                                       nccl_ctx.comm_, nccl_ctx.stream());
+        } else {
+          if (static_cast<size_t>(var_dev_id)) {
+            platform::dynload::ncclBcast(buffers[i], numel, data_type,
+                                         var_dev_id, nccl_ctx.comm_,
+                                         nccl_ctx.stream());
+          }
+        }
       }
       member_->nccl_ctxs_->WaitAll();
     }
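Below is a minimal standalone sketch of the broadcast-root selection this change introduces, using hypothetical names (FakeBcast, BroadcastParam) in place of Paddle's NCCL plumbing: during the initializing broadcast every device receives the parameter from device 0; afterwards a parameter is re-broadcast from the device that owns it (var_dev_id). The sketch spells the guard as var_dev_id >= 0 for readability; the patch itself writes static_cast<size_t>(var_dev_id), a truthiness check that also skips the case where the owner is device 0.

// Standalone sketch (hypothetical names, not Paddle's API) of the
// broadcast-root selection introduced by this patch.
#include <cstdio>
#include <vector>

// Stand-in for platform::dynload::ncclBcast: records which root each
// device would receive the parameter from.
void FakeBcast(int device, int root) {
  std::printf("device %d <- broadcast from root %d\n", device, root);
}

// initializing: first broadcast after startup; all weights live on GPU 0.
// var_dev_id:   device that owns the variable once placement is known;
//               a negative value means no owner was recorded.
void BroadcastParam(const std::vector<int> &devices, bool initializing,
                    int var_dev_id) {
  for (int dev : devices) {
    if (initializing) {
      FakeBcast(dev, /*root=*/0);           // initial weights come from GPU 0
    } else if (var_dev_id >= 0) {
      FakeBcast(dev, /*root=*/var_dev_id);  // re-broadcast from the owner
    }  // otherwise skip: no owner recorded for this variable
  }
}

int main() {
  std::vector<int> devices{0, 1, 2, 3};
  BroadcastParam(devices, /*initializing=*/true, /*var_dev_id=*/-1);
  BroadcastParam(devices, /*initializing=*/false, /*var_dev_id=*/2);
  return 0;
}

Broadcasting from the owning device rather than always from device 0 keeps each parameter's authoritative copy on the device that holds it, which is what the unconditional root-0 ncclBcast removed by this patch got wrong.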