diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 250e093a5f789dba6b06df4889c060c294d469fe..8f319116ab80b75c624f35b0e1315e7362e88d9a 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -127,6 +127,9 @@ static const char kLocalScopes[] = "local_scopes"; static const char kStrategy[] = "strategy"; void MultiDevSSAGraphBuilder::Init() const { + all_vars_.clear(); + balance_vars_.clear(); + loss_var_name_ = Get(kLossVarName); places_ = Get>(kPlaces); local_scopes_ = Get>(kLocalScopes); diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index 1ca8c4b855f9468589e537245380451a91a50b14..47aaa80f4d66a48b729d0638badcab885a50585c 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -40,12 +40,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass { size_t device_id) const; void Init() const; - private: - mutable std::string loss_var_name_; - mutable std::vector places_; - mutable std::vector local_scopes_; - mutable std::unordered_set grad_names_; - #ifdef PADDLE_WITH_CUDA mutable platform::NCCLContextMap *nccl_ctxs_; #endif @@ -95,13 +89,17 @@ class MultiDevSSAGraphBuilder : public ir::Pass { size_t GetAppropriateDeviceID( const std::vector &var_names) const; - private: + void SetCommunicationContext(OpHandleBase *op_handle, + const platform::Place &p) const; + + mutable std::string loss_var_name_; + mutable std::vector places_; + mutable std::vector local_scopes_; + mutable std::unordered_set grad_names_; + mutable BuildStrategy strategy_; mutable std::unordered_map all_vars_; mutable std::vector balance_vars_; - - void SetCommunicationContext(OpHandleBase *op_handle, - const platform::Place &p) const; }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index ae393d66a3b3ec0141667b44b5d9f3158e434e37..dbc3ff8657a1f2238951a791fb5ac3356c885770 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -233,30 +233,9 @@ ParallelExecutor::ParallelExecutor( void ParallelExecutor::BCastParamsToDevices( const std::unordered_set &vars) const { - // the initializing bcast, all vars would be bcast from device(0), - // otherwise - // bcast from the specified device. - bool initializing = member_->executor_ ? false : true; + // the initializing bcast, all vars would be bcast from device(0). for (auto &var : vars) { - int var_dev_id = -1; - if (member_->executor_) { - auto &sharded_var_device = - member_->executor_->Graph().Get( - details::kShardedVarDevice); - if (sharded_var_device.find(var) != sharded_var_device.end()) { - var_dev_id = sharded_var_device.at(var); - } - } - - if (!initializing && var_dev_id == -1) continue; - - framework::Variable *main_var = nullptr; - if (initializing) { - main_var = member_->local_scopes_[0]->FindVar(var); - } else { - main_var = member_->local_scopes_[var_dev_id]->FindVar(var); - } - + framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var); if (main_var == nullptr || !main_var->IsType()) { continue; } @@ -272,8 +251,7 @@ void ParallelExecutor::BCastParamsToDevices( auto place = member_->places_[i]; void *buffer; - if ((initializing && i == 0) || - (!initializing && static_cast(i) == var_dev_id)) { + if (i == 0) { buffer = const_cast(main_tensor.data()); } else { auto local_scope = member_->local_scopes_[i]; @@ -290,29 +268,18 @@ void ParallelExecutor::BCastParamsToDevices( platform::NCCLGroupGuard guard; for (size_t i = 0; i < member_->places_.size(); ++i) { auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]); - if (initializing) { - platform::dynload::ncclBcast(buffers[i], numel, data_type, 0, - nccl_ctx.comm_, nccl_ctx.stream()); - } else { - if (var_dev_id >= 0) { - platform::dynload::ncclBcast(buffers[i], numel, data_type, - var_dev_id, nccl_ctx.comm_, - nccl_ctx.stream()); - } - } + platform::dynload::ncclBcast(buffers[i], numel, data_type, 0, + nccl_ctx.comm_, nccl_ctx.stream()); } member_->nccl_ctxs_->WaitAll(); } - #else PADDLE_THROW("Not compiled with CUDA"); #endif } else { platform::CPUPlace cpu; for (size_t i = 0; i < member_->places_.size(); ++i) { - if ((initializing && i == 0) || - (!initializing && static_cast(i) == var_dev_id)) - continue; + if (i == 0) continue; auto local_scope = member_->local_scopes_[i]; auto *t = local_scope->Var(var)->GetMutable(); diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 88e2078454024c3a4d437615d3e6b15ee0c7d6a1..c64906ff230df5f2b7cc9f5c6b29d68956ab8f33 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -72,9 +72,9 @@ class ParallelExecutor { void Run(const std::vector &fetch_tensors, const std::string &fetched_var_name); + private: void BCastParamsToDevices(const std::unordered_set &vars) const; - private: ParallelExecutorPrivate *member_; #ifdef PADDLE_WITH_CUDA