diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index dfebf36d04ce5be928162346525be5e4039de3c0..485d89aa562fe85ac549c9a967ab321b94e3fcfe 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -153,7 +153,6 @@ void ParallelExecutor::BCastParamsToGPUs( if (main_var == nullptr || !main_var->IsType()) { continue; } - VLOG(3) << "run broadcast " << var << " " << var_dev_id; auto &main_tensor = main_var->Get(); auto &dims = main_tensor.dims(); @@ -184,8 +183,16 @@ void ParallelExecutor::BCastParamsToGPUs( platform::NCCLGroupGuard guard; for (size_t i = 0; i < member_->places_.size(); ++i) { auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]); - platform::dynload::ncclBcast(buffers[i], numel, data_type, 0, - nccl_ctx.comm_, nccl_ctx.stream()); + if (initializing) { + platform::dynload::ncclBcast(buffers[i], numel, data_type, 0, + nccl_ctx.comm_, nccl_ctx.stream()); + } else { + if (static_cast(var_dev_id)) { + platform::dynload::ncclBcast(buffers[i], numel, data_type, + var_dev_id, nccl_ctx.comm_, + nccl_ctx.stream()); + } + } } member_->nccl_ctxs_->WaitAll(); }