From d470763f6c0e7641367641bdb6cb1f28b8cf39c3 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 16 Mar 2018 15:53:36 +0800 Subject: [PATCH] Stash --- paddle/fluid/framework/parallel_executor.cc | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index ea5ce3f2e9c..215ee38ac58 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -154,6 +154,8 @@ class ParallelExecutorPrivate { std::unordered_map local_scopes_; + std::vector places_; + #ifdef PADDLE_WITH_CUDA struct NCCLContext { std::unique_ptr ctx_; @@ -246,6 +248,8 @@ ParallelExecutor::ParallelExecutor( const ProgramDesc &startup_program, const ProgramDesc &main_program, const std::string &loss_var_name, Scope *scope) : member_(new ParallelExecutorPrivate()) { + member_->places_ = places; + // Step 1. RunStartupProgram and Bcast the params to devs. Executor exe(places[0]); exe.Run(startup_program, scope, 0); @@ -489,14 +493,14 @@ void ParallelExecutor::BCastParamsToGPUs( platform::dynload::ncclGroupStart(); - for (auto &pair : member_->local_scopes_) { - auto local_scope = pair.second; + for (auto &place : member_->places_) { + auto local_scope = member_->local_scopes_[place]; auto *t = local_scope->Var(var_desc->Name())->GetMutable(); t->Resize(dims); - auto &nccl_ctx = member_->GetNCCLCtx(pair.first); - platform::dynload::ncclBcast( - t->mutable_data(pair.first, main_tensor.type()), numel, data_type, - 0, nccl_ctx.comm, nccl_ctx.stream()); + auto &nccl_ctx = member_->GetNCCLCtx(place); + platform::dynload::ncclBcast(t->mutable_data(place, main_tensor.type()), + numel, data_type, 0, nccl_ctx.comm, + nccl_ctx.stream()); } platform::dynload::ncclGroupEnd(); } @@ -506,7 +510,7 @@ void ParallelExecutor::BCastParamsToGPUs( for (auto &pair : member_->local_scopes_) { member_->GetNCCLCtx(pair.first).ctx_->Wait(); - auto &b = pair.second->FindVar("fc_1.b_0")->Get(); + auto &b = pair.second->FindVar("fc_0.b_0")->Get(); framework::LoDTensor cpu; framework::TensorCopy(b, platform::CPUPlace(), &cpu); platform::DeviceContextPool::Instance().Get(b.place())->Wait(); -- GitLab