From 8dec4ad7a1c37b705b584e64c3eef4d6df320c13 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 21 Mar 2018 17:12:27 +0800
Subject: [PATCH] Use int not Place for vars

---
 paddle/fluid/framework/parallel_executor.cc | 46 ++++++++++-----------
 1 file changed, 21 insertions(+), 25 deletions(-)

diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 440040a2ef..d3919f0d51 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -28,6 +28,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+using details::ComputationOpHandle;
 using details::DummyVarHandle;
 using details::FetchOpHandle;
 using details::NCCLAllReduceOpHandle;
@@ -35,7 +36,6 @@ using details::OpHandleBase;
 using details::ScaleLossGradOpHandle;
 using details::VarHandle;
 using details::VarHandleBase;
-using details::ComputationOpHandle;
 
 class ParallelExecutorPrivate {
  public:
@@ -43,7 +43,9 @@ class ParallelExecutorPrivate {
                           const std::vector<platform::Place> &places)
       : places_(places),
         fetch_dev_ctxs_(places),
-        pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {}
+        pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {
+    vars_.resize(places.size());
+  }
 
   std::vector<platform::Place> places_;
   platform::DeviceContextPool fetch_dev_ctxs_;
@@ -52,12 +54,7 @@ class ParallelExecutorPrivate {
 
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
 
-  platform::Place main_place_;
-
-  std::unordered_map<platform::Place,
-                     std::unordered_map<std::string, std::map<int, VarHandle>>,
-                     platform::PlaceHash>
-      vars_;
+  std::vector<std::unordered_map<std::string, std::map<int, VarHandle>>> vars_;
 
   std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
 
@@ -69,8 +66,8 @@ class ParallelExecutorPrivate {
   std::unique_ptr<platform::EnforceNotMet> exception_;
 
   VarHandle *GetVarHandle(const std::string &each_var_name,
-                          const platform::Place &place) {
-    auto &var_holders = vars_[place];
+                          const platform::Place &place, size_t place_offset) {
+    auto &var_holders = vars_[place_offset];
     auto &var_holder = var_holders[each_var_name];
     VarHandle *var = nullptr;
     if (var_holder.empty()) {
@@ -118,8 +115,8 @@ class ParallelExecutorPrivate {
   }
 
   void GenerateVar(OpHandleBase *op_handle, const std::string &each_var_name,
-                   const platform::Place &place) {
-    auto &vars = vars_[place][each_var_name];
+                   const platform::Place &place, size_t place_offset) {
+    auto &vars = vars_[place_offset][each_var_name];
     size_t version = vars.size();
     auto &var = vars[version];
     var.version_ = version;
@@ -144,11 +141,10 @@ ParallelExecutor::ParallelExecutor(
   for (size_t i = 0; i < member_->places_.size(); ++i) {
     member_->local_scopes_.push_back(&scope->NewScope());
   }
-  member_->main_place_ = places[0];
 
   // Bcast Parameters to all GPUs
   BuildNCCLCommunicator();
-  if (platform::is_gpu_place(member_->main_place_) &&
+  if (platform::is_gpu_place(places[0]) &&
       member_->local_scopes_.size() != 1) {  // Is CUDA
     BCastParamsToGPUs(startup_program);
   }
@@ -201,13 +197,13 @@ void ParallelExecutor::ConstructDependencyGraph(
 
       auto var_names = op->InputArgumentNames();
       for (auto &each_var_name : var_names) {
-        VarHandle *var = member_->GetVarHandle(each_var_name, p);
+        VarHandle *var = member_->GetVarHandle(each_var_name, p, i);
         op_handle->AddInput(var);
       }
 
       var_names = op->OutputArgumentNames();
       for (auto &each_var_name : var_names) {
-        member_->GenerateVar(op_handle, each_var_name, p);
+        member_->GenerateVar(op_handle, each_var_name, p, i);
       }
 
       if (is_forwarding) {
@@ -224,7 +220,7 @@ void ParallelExecutor::ConstructDependencyGraph(
           //        loss->pending_ops_.emplace_back(op_handle);
           //        op_handle->inputs_.emplace_back(loss);
 
-          member_->GenerateVar(op_handle, loss_var_name + "@GRAD", p);
+          member_->GenerateVar(op_handle, loss_var_name + "@GRAD", p, i);
           change_forward = true;
         }
       }
@@ -245,7 +241,7 @@ void ParallelExecutor::ConstructDependencyGraph(
 
         for (size_t i = 0; i < member_->places_.size(); ++i) {
           auto &p = member_->places_[i];
-          auto &vars = member_->vars_[p][og];
+          auto &vars = member_->vars_[i][og];
           if (vars.empty()) {  // This device has no data. continue.
             continue;
           }
@@ -280,8 +276,8 @@
  * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
  */
 void ParallelExecutor::PolishGraphToSupportDataHazards() const {
-  for (auto &place_pair : member_->vars_) {
-    for (auto &name_pair : place_pair.second) {
+  for (auto &var_map : member_->vars_) {
+    for (auto &name_pair : var_map) {
       if (name_pair.second.size() <= 1) {
         return;
       }
@@ -369,8 +365,8 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
   std::unordered_map<OpHandleBase *, size_t> pending_ops;
   std::vector<DummyVarHandle> dummy_vars;
 
-  for (auto &place_pair : member_->vars_) {
-    for (auto &name_pair : place_pair.second) {
+  for (auto &var_map : member_->vars_) {
+    for (auto &name_pair : var_map) {
       for (auto &version_pair : name_pair.second) {
         pending_vars[&version_pair.second] =
             version_pair.second.generated_op_ == nullptr;
@@ -395,9 +391,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
   std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
 
   for (auto &fetch_var_name : fetch_tensors) {
-    for (auto &pair : member_->vars_) {
-      auto it = pair.second.find(fetch_var_name);
-      if (it != pair.second.end()) {
+    for (auto &var_map : member_->vars_) {
+      auto it = var_map.find(fetch_var_name);
+      if (it != var_map.end()) {
         fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
       }
     }
--
GitLab
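
The heart of this patch is re-keying the per-device variable tables: vars_ changes from a std::unordered_map keyed by platform::Place (which needs the custom platform::PlaceHash) to a std::vector indexed by the device's offset in the places list, and GetVarHandle / GenerateVar gain a place_offset parameter so callers pass the loop index i directly. Below is a minimal standalone sketch of that idea, assuming simplified stand-ins for VarHandle and the registry; the names VarRegistry, Generate, and Latest are illustrative inventions, not PaddlePaddle's real API.

#include <cassert>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for details::VarHandle; only the fields this sketch needs.
struct VarHandle {
  size_t version = 0;
  std::string name;
};

// name -> (version -> handle); one table per device, as in the patch.
using VarTable = std::unordered_map<std::string, std::map<int, VarHandle>>;

class VarRegistry {
 public:
  explicit VarRegistry(size_t num_places) : vars_(num_places) {}

  // Mirrors GenerateVar: append the next version of `name` on the device
  // at `place_offset`. Plain vector indexing, no Place hashing involved.
  VarHandle *Generate(const std::string &name, size_t place_offset) {
    auto &versions = vars_[place_offset][name];
    int version = static_cast<int>(versions.size());
    VarHandle &var = versions[version];  // default-constructs the new slot
    var.version = static_cast<size_t>(version);
    var.name = name;
    return &var;
  }

  // Mirrors the fetch loop: newest version of `name` on one device, if any.
  const VarHandle *Latest(const std::string &name, size_t place_offset) const {
    const VarTable &table = vars_[place_offset];
    auto it = table.find(name);
    if (it == table.end() || it->second.empty()) return nullptr;
    return &it->second.rbegin()->second;  // std::map keeps versions sorted
  }

 private:
  std::vector<VarTable> vars_;  // index == offset into the places list
};

int main() {
  VarRegistry registry(/*num_places=*/2);
  registry.Generate("loss@GRAD", 0);
  registry.Generate("loss@GRAD", 0);  // a second write becomes version 1
  assert(registry.Latest("loss@GRAD", 0)->version == 1);
  assert(registry.Latest("loss@GRAD", 1) == nullptr);  // device 1 untouched
  return 0;
}

Indexing by offset avoids hashing a Place on every lookup, removes the need for platform::PlaceHash entirely, and makes iteration over devices follow the deterministic order of the places list rather than unordered_map bucket order.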