diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 215ee38ac585093f3d0defa5e9bf91add1355e66..996273c720a2e821ac37842d066a09f84ad95311 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -171,27 +171,28 @@ class ParallelExecutorPrivate {
       return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
     }
 
-    static void InitNCCLContext(std::map<int, NCCLContext> &contexts) {
+    static void InitNCCLContext(std::unordered_map<int, NCCLContext> &contexts,
+                                const std::vector<platform::Place> &places) {
       std::vector<ncclComm_t> comms;
       std::vector<int> devs;
       comms.resize(contexts.size());
       devs.reserve(contexts.size());
 
-      for (auto &ctx : contexts) {
-        devs.push_back(ctx.first);
+      for (auto &p : places) {
+        devs.push_back(boost::get<platform::CUDAPlace>(p).device);
       }
 
       NCCL_INVOKE(platform::dynload::ncclCommInitAll(
           &comms[0], static_cast<int>(contexts.size()), &devs[0]));
 
       int i = 0;
-      for (auto &ctx : contexts) {
-        ctx.second.comm = comms[i++];
+      for (auto &dev_id : devs) {
+        contexts.at(dev_id).comm = comms[i++];
       }
     }
   };
 
-  std::map<int, NCCLContext> communication_streams_;
+  std::unordered_map<int, NCCLContext> communication_streams_;
 
   NCCLContext &GetNCCLCtx(platform::Place p) {
     int dev_id = boost::get<platform::CUDAPlace>(p).device;
@@ -493,13 +494,20 @@ void ParallelExecutor::BCastParamsToGPUs(
 
         platform::dynload::ncclGroupStart();
 
-        for (auto &place : member_->places_) {
-          auto local_scope = member_->local_scopes_[place];
-          auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
-          t->Resize(dims);
+        for (size_t i = 0; i < member_->places_.size(); ++i) {
+          auto place = member_->places_[i];
+          void *buffer;
+          if (i == 0) {
+            buffer = const_cast<void *>(main_tensor.data<void>());
+          } else {
+            auto local_scope = member_->local_scopes_[place];
+            auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
+            t->Resize(dims);
+            buffer = t->mutable_data(place, main_tensor.type());
+          }
+
           auto &nccl_ctx = member_->GetNCCLCtx(place);
-          platform::dynload::ncclBcast(t->mutable_data(place, main_tensor.type()),
-                                       numel, data_type, 0, nccl_ctx.comm,
+          platform::dynload::ncclBcast(buffer, numel, data_type, 0, nccl_ctx.comm,
                                        nccl_ctx.stream());
         }
         platform::dynload::ncclGroupEnd();
@@ -533,7 +541,7 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
   }
 
   ParallelExecutorPrivate::NCCLContext::InitNCCLContext(
-      member_->communication_streams_);
+      member_->communication_streams_, member_->places_);
 
 #endif
 }
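
For reference, below is a minimal standalone sketch (not Paddle code; the device list, buffer size, and absence of error checking are illustrative assumptions) of the two NCCL patterns this diff relies on: ncclCommInitAll driven by an explicit, stably ordered device list, which is why InitNCCLContext now takes `places` instead of iterating the map, and a grouped per-device ncclBcast in which rank 0 passes its existing buffer as the source rather than allocating a copy.

// bcast_sketch.cu -- hypothetical example; compile with: nvcc bcast_sketch.cu -lnccl
#include <cuda_runtime.h>
#include <nccl.h>

#include <vector>

int main() {
  // Assumption: two visible GPUs. The diff instead derives this list from
  // member_->places_ so that NCCL rank order matches place order.
  std::vector<int> devs = {0, 1};
  int ndev = static_cast<int>(devs.size());

  // One communicator per device, created in the order of `devs`:
  // rank i corresponds to devs[i], mirroring InitNCCLContext above.
  std::vector<ncclComm_t> comms(ndev);
  ncclCommInitAll(comms.data(), ndev, devs.data());

  const size_t numel = 1024;
  std::vector<void *> buffers(ndev);
  std::vector<cudaStream_t> streams(ndev);
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(devs[i]);
    cudaMalloc(&buffers[i], numel * sizeof(float));
    cudaStreamCreate(&streams[i]);
  }
  // Pretend rank 0 already holds the parameter, like main_tensor in the diff.
  cudaSetDevice(devs[0]);
  cudaMemset(buffers[0], 0, numel * sizeof(float));

  // Grouped broadcast: one ncclBcast per device between
  // ncclGroupStart()/ncclGroupEnd(), with root rank 0 reusing its buffer.
  ncclGroupStart();
  for (int i = 0; i < ndev; ++i) {
    ncclBcast(buffers[i], numel, ncclFloat, 0 /*root rank*/, comms[i],
              streams[i]);
  }
  ncclGroupEnd();

  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(devs[i]);
    cudaStreamSynchronize(streams[i]);
    cudaFree(buffers[i]);
    cudaStreamDestroy(streams[i]);
    ncclCommDestroy(comms[i]);
  }
  return 0;
}

The explicit device list also motivates the container change: std::map iterated its int keys in sorted order, but std::unordered_map gives no ordering guarantee, so assigning comms[i++] by map iteration could pair a communicator with the wrong device; building `devs` from `places` and indexing `contexts.at(dev_id)` keeps the rank-to-device mapping deterministic.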