diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index d3919f0d51b5b97844ba841abf57294531aef39a..37bfdc0df5273c6eda3754cb766fa0dff3077248 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -37,6 +37,86 @@ using details::ScaleLossGradOpHandle;
 using details::VarHandle;
 using details::VarHandleBase;
 
+struct SSAGraph {
+  std::vector<std::unordered_map<std::string, std::map<int, VarHandle>>> vars_;
+  std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
+  std::vector<std::unique_ptr<OpHandleBase>> ops_;
+};
+
+/**
+ * We only handle write-after-read (WAR) hazards, since a program should not
+ * contain write-after-write; if there are write-after-write operators, we
+ * need to prune them.
+ *
+ * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
+ */
+static void PolishGraphToSupportDataHazards(SSAGraph *graph) {
+  for (auto &var_map : graph->vars_) {
+    for (auto &name_pair : var_map) {
+      if (name_pair.second.size() <= 1) {
+        return;
+      }
+      auto it_new = name_pair.second.rbegin();
+      auto it_old = name_pair.second.rbegin();
+      ++it_old;
+      for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
+        auto *write_op = it_new->second.generated_op_;
+        auto &read_ops = it_old->second.pending_ops_;
+        auto *ex_write_op = it_old->second.generated_op_;
+
+        if (ex_write_op == nullptr) {  // Nobody writes this var.
+          continue;
+        }
+
+        for (auto *read_op : read_ops) {
+          // Manually add a dependency var from read_op to write_op.
+          if (read_op == write_op) {
+            // Reader and writer are the same op.
+            continue;
+          }
+
+          auto *dep_var = new DummyVarHandle();
+          read_op->AddOutput(dep_var);
+          write_op->AddInput(dep_var);
+          graph->dep_vars_.emplace(dep_var);
+        }
+      }
+    }
+  }
+}
+
+static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
+                                             const std::string &each_var_name,
+                                             const platform::Place &place,
+                                             size_t place_offset) {
+  auto &var_holders = graph->vars_[place_offset];
+  auto &var_holder = var_holders[each_var_name];
+  VarHandle *var = nullptr;
+  if (var_holder.empty()) {
+    auto &init_var = var_holder[0];
+    init_var.place_ = place;
+    init_var.name_ = each_var_name;
+    init_var.generated_op_ = nullptr;
+    init_var.version_ = 0;
+    var = &init_var;
+  } else {
+    var = &var_holder.rbegin()->second;
+  }
+  return var;
+}
+
+static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
+                           const std::string &each_var_name,
+                           const platform::Place &place, size_t place_offset) {
+  auto &vars = graph->vars_[place_offset][each_var_name];
+  size_t version = vars.size();
+  auto &var = vars[version];
+  var.version_ = version;
+  var.name_ = each_var_name;
+  var.place_ = place;
+  op_handle->AddOutput(&var);
+}
+
 class ParallelExecutorPrivate {
  public:
   explicit ParallelExecutorPrivate(size_t num_threads,
@@ -44,7 +124,7 @@ class ParallelExecutorPrivate {
       : places_(places),
         fetch_dev_ctxs_(places),
         pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {
-    vars_.resize(places.size());
+    graph_.vars_.resize(places.size());
   }
 
   std::vector<platform::Place> places_;
@@ -54,35 +134,13 @@ class ParallelExecutorPrivate {
 
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
 
-  std::vector<std::unordered_map<std::string, std::map<int, VarHandle>>> vars_;
-
-  std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
-
-  std::vector<std::unique_ptr<OpHandleBase>> ops_;
+  SSAGraph graph_;
 
   // Use a simpler thread pool, might be faster.
   std::unique_ptr<ThreadPool> pool_;
 
   std::unique_ptr<platform::EnforceNotMet> exception_;
 
-  VarHandle *GetVarHandle(const std::string &each_var_name,
-                          const platform::Place &place, size_t place_offset) {
-    auto &var_holders = vars_[place_offset];
-    auto &var_holder = var_holders[each_var_name];
-    VarHandle *var = nullptr;
-    if (var_holder.empty()) {
-      auto &init_var = var_holder[0];
-      init_var.place_ = place;
-      init_var.name_ = each_var_name;
-      init_var.generated_op_ = nullptr;
-      init_var.version_ = 0;
-      var = &init_var;
-    } else {
-      var = &var_holder.rbegin()->second;
-    }
-    return var;
-  }
-
   void RunOp(
       bool use_event,
       std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
@@ -113,17 +171,6 @@ class ParallelExecutorPrivate {
       op_run();
     }
   }
-
-  void GenerateVar(OpHandleBase *op_handle, const std::string &each_var_name,
-                   const platform::Place &place, size_t place_offset) {
-    auto &vars = vars_[place_offset][each_var_name];
-    size_t version = vars.size();
-    auto &var = vars[version];
-    var.version_ = version;
-    var.name_ = each_var_name;
-    var.place_ = place;
-    op_handle->AddOutput(&var);
-  }
 };
 
 ParallelExecutor::ParallelExecutor(
@@ -189,21 +236,22 @@ void ParallelExecutor::ConstructDependencyGraph(
       auto &p = member_->places_[i];
       auto *s = member_->local_scopes_[i];
 
-      member_->ops_.emplace_back(new ComputationOpHandle(*op, s, p));
-      auto *op_handle = member_->ops_.back().get();
+      member_->graph_.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
+      auto *op_handle = member_->graph_.ops_.back().get();
       op_handle->dev_ctx_[p] = const_cast<platform::DeviceContext *>(
           platform::DeviceContextPool::Instance().Get(p));
 
       auto var_names = op->InputArgumentNames();
 
       for (auto &each_var_name : var_names) {
-        VarHandle *var = member_->GetVarHandle(each_var_name, p, i);
+        VarHandle *var =
+            CreateOrGetLatestVarHandle(&member_->graph_, each_var_name, p, i);
         op_handle->AddInput(var);
       }
       var_names = op->OutputArgumentNames();
 
       for (auto &each_var_name : var_names) {
-        member_->GenerateVar(op_handle, each_var_name, p, i);
+        CreateOpOutput(&member_->graph_, op_handle, each_var_name, p, i);
       }
 
       if (is_forwarding) {
@@ -212,7 +260,7 @@ void ParallelExecutor::ConstructDependencyGraph(
           op_handle =
               new ScaleLossGradOpHandle(this->member_->local_scopes_.size(), s,
                                         p, member_->nccl_ctxs_->DevCtx(p));
-          member_->ops_.emplace_back(op_handle);
+          member_->graph_.ops_.emplace_back(op_handle);
 
           // FIXME: Currently ScaleLossGradOp only use device_count as scale
           // factor. So it does not depend on any other operators.
@@ -220,7 +268,8 @@ void ParallelExecutor::ConstructDependencyGraph(
           //   loss->pending_ops_.emplace_back(op_handle);
           //   op_handle->inputs_.emplace_back(loss);
 
-          member_->GenerateVar(op_handle, loss_var_name + "@GRAD", p, i);
+          CreateOpOutput(&member_->graph_, op_handle, loss_var_name + "@GRAD",
+                         p, i);
           change_forward = true;
         }
       }
@@ -235,13 +284,13 @@ void ParallelExecutor::ConstructDependencyGraph(
       for (auto &og : var_names) {
         if (grads.count(og) != 0) {  // is param grad
           // Insert NCCL AllReduce Op
-          member_->ops_.emplace_back(new NCCLAllReduceOpHandle(
+          member_->graph_.ops_.emplace_back(new NCCLAllReduceOpHandle(
              member_->local_scopes_, member_->places_, *member_->nccl_ctxs_));
-          auto *op_handle = member_->ops_.back().get();
+          auto *op_handle = member_->graph_.ops_.back().get();
 
           for (size_t i = 0; i < member_->places_.size(); ++i) {
             auto &p = member_->places_[i];
-            auto &vars = member_->vars_[i][og];
+            auto &vars = member_->graph_.vars_[i][og];
 
             if (vars.empty()) {  // This device has no data. continue.
               continue;
@@ -265,49 +314,7 @@ void ParallelExecutor::ConstructDependencyGraph(
     Dependency graph has been constructed. However, there are still data
     hazards that need to be handled.
    */
-  PolishGraphToSupportDataHazards();
-}
-
-/**
- * We only handle write after read(WAR), since it should not have a write
- * after write in program. If there are write after write operators, we need
- * prune them.
- *
- * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
- */
-void ParallelExecutor::PolishGraphToSupportDataHazards() const {
-  for (auto &var_map : member_->vars_) {
-    for (auto &name_pair : var_map) {
-      if (name_pair.second.size() <= 1) {
-        return;
-      }
-      auto it_new = name_pair.second.rbegin();
-      auto it_old = name_pair.second.rbegin();
-      ++it_old;
-      for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
-        auto *write_op = it_new->second.generated_op_;
-        auto &read_ops = it_old->second.pending_ops_;
-        auto *ex_write_op = it_old->second.generated_op_;
-
-        if (ex_write_op == nullptr) {  // Nobody write this var.
-          continue;
-        }
-
-        for (auto *read_op : read_ops) {
-          // Manually add a dependency var from read_op to write_op;
-          if (read_op == write_op) {
-            // Read Write is the same op.
-            continue;
-          }
-
-          auto *dep_var = new DummyVarHandle();
-          read_op->AddOutput(dep_var);
-          write_op->AddInput(dep_var);
-          member_->dep_vars_.emplace(dep_var);
-        }
-      }
-    }
-  }
+  PolishGraphToSupportDataHazards(&member_->graph_);
 }
 
 void ParallelExecutor::BCastParamsToGPUs(
@@ -365,7 +372,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
   std::unordered_map<OpHandleBase *, size_t> pending_ops;
   std::vector<DummyVarHandle> dummy_vars;
 
-  for (auto &var_map : member_->vars_) {
+  for (auto &var_map : member_->graph_.vars_) {
     for (auto &name_pair : var_map) {
       for (auto &version_pair : name_pair.second) {
         pending_vars[&version_pair.second] =
@@ -374,13 +381,13 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
     }
   }
 
-  for (auto &var : member_->dep_vars_) {
+  for (auto &var : member_->graph_.dep_vars_) {
     pending_vars[var.get()] = var->generated_op_ == nullptr;
   }
 
   std::vector<OpHandleBase *> to_run;
 
-  for (auto &op : member_->ops_) {
+  for (auto &op : member_->graph_.ops_) {
     if (op->inputs_.empty()) {  // Special case, Op has no input.
       to_run.emplace_back(op.get());
     } else {
@@ -391,7 +398,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
   std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
 
   for (auto &fetch_var_name : fetch_tensors) {
-    for (auto &var_map : member_->vars_) {
+    for (auto &var_map : member_->graph_.vars_) {
       auto it = var_map.find(fetch_var_name);
       if (it != var_map.end()) {
         fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index 466b5f5f62d4d8ccbebb64b578911c67dfcd0709..8c91c45d1462f6f968631078931578a081409924 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -52,8 +52,6 @@ class ParallelExecutor {
                                  const std::string& loss_var_name) const;
 
   void BuildNCCLCommunicator() const;
-
-  void PolishGraphToSupportDataHazards() const;
 };
 
 }  // namespace framework
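For review context, here is a minimal, self-contained sketch of the SSA bookkeeping this patch extracts into `SSAGraph`. The `Var`/`Op`/`Graph` types and the `AddOp`/`Read`/`Write`/`Polish` helpers are hand-rolled stand-ins for illustration, not Paddle's real `VarHandle`/`OpHandleBase` API. The idea it models: every write starts a new integer version of a variable (as `CreateOpOutput` does), every read takes the newest version (as `CreateOrGetLatestVarHandle` does), and the polish pass links each older version's readers to the newer version's writer through a dummy variable, so a scheduler cannot start the write while reads of the previous version are still outstanding.

```cpp
// Simplified model of SSAGraph versioning and WAR-hazard polishing.
// Not Paddle code; all names here are invented for this sketch.
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct Op;

struct Var {
  std::string name;
  int version = 0;
  Op *generated_op = nullptr;     // op that produced this version, if any
  std::vector<Op *> pending_ops;  // ops that consume this version
};

struct Op {
  std::string name;
  std::vector<Var *> inputs, outputs;
};

struct Graph {
  // name -> (version -> handle); mirrors SSAGraph::vars_ for one place.
  std::map<std::string, std::map<int, Var>> vars;
  std::vector<std::unique_ptr<Var>> dep_vars;  // dummy WAR dependencies
  std::vector<std::unique_ptr<Op>> ops;

  Op *AddOp(const std::string &name) {
    ops.push_back(std::make_unique<Op>());
    ops.back()->name = name;
    return ops.back().get();
  }

  // Read the newest version, creating an op-less version 0 if the
  // variable was never written (cf. CreateOrGetLatestVarHandle).
  void Read(Op *op, const std::string &name) {
    auto &versions = vars[name];
    Var *v = versions.empty() ? &versions[0] : &versions.rbegin()->second;
    v->name = name;
    op->inputs.push_back(v);
    v->pending_ops.push_back(op);
  }

  // Every write starts a fresh version (cf. CreateOpOutput).
  void Write(Op *op, const std::string &name) {
    auto &versions = vars[name];
    int version = static_cast<int>(versions.size());
    Var *v = &versions[version];
    v->name = name;
    v->version = version;
    v->generated_op = op;
    op->outputs.push_back(v);
  }

  // The version walk of PolishGraphToSupportDataHazards: for each pair of
  // consecutive versions, connect the old version's readers to the new
  // version's writer via a dummy variable.
  void Polish() {
    for (auto &name_pair : vars) {
      if (name_pair.second.size() <= 1) continue;
      auto it_new = name_pair.second.rbegin();
      auto it_old = name_pair.second.rbegin();
      ++it_old;
      for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
        Op *write_op = it_new->second.generated_op;
        if (it_old->second.generated_op == nullptr) continue;  // graph input
        for (Op *read_op : it_old->second.pending_ops) {
          if (read_op == write_op) continue;  // reader and writer coincide
          auto dep = std::make_unique<Var>();
          read_op->outputs.push_back(dep.get());  // reader produces the dummy,
          write_op->inputs.push_back(dep.get());  // writer waits for it
          dep_vars.push_back(std::move(dep));
        }
      }
    }
  }
};

int main() {
  Graph g;
  Op *init = g.AddOp("fill");  // produces w:0
  Op *fc = g.AddOp("fc");      // reads w:0
  Op *sgd = g.AddOp("sgd");    // writes w:1 -> WAR hazard against fc
  g.Write(init, "w");
  g.Read(fc, "w");
  g.Write(sgd, "w");
  g.Polish();
  // One dummy var now forces sgd (writer of w:1) after fc (reader of w:0).
  std::cout << "dummy deps inserted: " << g.dep_vars.size() << "\n";
}
```

Modeling the WAR edge as an extra variable rather than a direct op-to-op edge keeps the scheduler uniform: `ParallelExecutor::Run` only ever waits on variables, which is why the dummies are kept alive in `SSAGraph::dep_vars_` and registered as ordinary entries in `pending_vars`.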