diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index a584b3a708be30278a0c9ad047c727364424f9ef..b6d1ee50739eb4388aeda51783232c9f59cf83d7 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -21,15 +21,14 @@ namespace details { AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( const ExecutionStrategy &strategy, const std::vector &local_scopes, const std::vector &places, - std::vector> &&graphs) + std::unique_ptr &&graph) : strategy_(std::move(strategy)), local_scopes_(std::move(local_scopes)), pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr), places_(std::move(places)), - graphs_(std::move(graphs)) { + graph_(std::move(graph)) { VLOG(3) << "build AsyncSSAGraphExecutor"; PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); - PADDLE_ENFORCE_EQ(graphs_.size(), local_scopes_.size()); // set the correct size of thread pool to each device. strategy_.num_threads_ = strategy_.num_threads_ < places_.size() @@ -39,7 +38,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( << " to run the operators of the graph on each device."; for (size_t i = 0; i < places.size(); ++i) { executors_.emplace_back(new details::ThreadedSSAGraphExecutor( - strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i]))); + strategy_, {local_scopes_[i]}, {places_[i]}, graph_.get())); } } diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.h b/paddle/fluid/framework/details/async_ssa_graph_executor.h index 4091c56d743c6ede4ea95076a91590427bc05a14..50f207361fb1c4579d4c86b09c019914882f2f5c 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.h @@ -29,9 +29,9 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor { AsyncSSAGraphExecutor(const ExecutionStrategy &strategy, const std::vector &local_scopes, const std::vector &places, - std::vector> &&graphs); + std::unique_ptr &&graph); ~AsyncSSAGraphExecutor() final = default; - const ir::Graph &Graph() const override { return *graphs_[0]; } + const ir::Graph &Graph() const override { return *graph_; } FeedFetchList Run(const std::vector &fetch_tensors) override; @@ -40,7 +40,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor { std::vector local_scopes_; std::unique_ptr<::ThreadPool> pool_{nullptr}; std::vector places_; - std::vector> graphs_; + std::unique_ptr graph_; std::vector> executors_; ExceptionHolder exception_holder_; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 8236773672562e04bc525f6e320f6c1784bee228..129d3a7f0d3580e3df0645239ca015d261e83f94 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -264,71 +264,59 @@ ParallelExecutor::ParallelExecutor( // Step 2. Convert main_program to SSA form and dependency graph. Also, insert // ncclOp - std::vector> graphs; + std::unique_ptr graph; #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { VLOG(3) << "use local async mode"; - for (size_t i = 0; i < member_->places_.size(); ++i) { - std::unique_ptr graph = build_strategy.Apply( - main_program, {member_->places_[i]}, loss_var_name, - {member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_, - member_->nccl_ctxs_.get()); - graphs.push_back(std::move(graph)); - } + graph = + build_strategy.Apply(main_program, {member_->places_[0]}, loss_var_name, + {member_->local_scopes_[0]}, member_->nranks_, + member_->use_cuda_, member_->nccl_ctxs_.get()); } else { - std::unique_ptr graph = build_strategy.Apply( - main_program, member_->places_, loss_var_name, member_->local_scopes_, - member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_.get()); - graphs.push_back(std::move(graph)); + graph = build_strategy.Apply(main_program, member_->places_, loss_var_name, + member_->local_scopes_, member_->nranks_, + member_->use_cuda_, member_->nccl_ctxs_.get()); } #else if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { VLOG(3) << "use local async mode"; - for (size_t i = 0; i < member_->places_.size(); ++i) { - std::unique_ptr graph = build_strategy.Apply( - main_program, {member_->places_[i]}, loss_var_name, - {member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_); - graphs.push_back(std::move(graph)); - } + graph = build_strategy.Apply(main_program, {member_->places_[0]}, + loss_var_name, {member_->local_scopes_[0]}, + member_->nranks_, member_->use_cuda_); } else { - std::unique_ptr graph = build_strategy.Apply( - main_program, member_->places_, loss_var_name, member_->local_scopes_, - member_->nranks_, member_->use_cuda_); - graphs.push_back(std::move(graph)); + graph = build_strategy.Apply(main_program, member_->places_, loss_var_name, + member_->local_scopes_, member_->nranks_, + member_->use_cuda_); } #endif auto max_memory_size = GetEagerDeletionThreshold(); VLOG(10) << "Eager Deletion Threshold " << static_cast(max_memory_size) / (1 << 30); if (max_memory_size >= 0) { - for (size_t i = 0; i < graphs.size(); ++i) { - graphs[i] = member_->PrepareGCAndRefCnts( - std::move(graphs[i]), static_cast(max_memory_size)); - } + graph = member_->PrepareGCAndRefCnts(std::move(graph), + static_cast(max_memory_size)); } // Step 3. Create vars in each scope. Passes may also create new vars. // skip control vars and empty vars std::vector var_infos; - for (auto &graph : graphs) { - for (auto &node : graph->Nodes()) { - if (node->IsVar() && !node->IsCtrlVar() && node->Var()) { - var_infos.emplace_back(); - var_infos.back().name_ = node->Var()->Name(); - var_infos.back().type_ = node->Var()->GetType(); - var_infos.back().persistable_ = node->Var()->Persistable(); - } + for (auto &node : graph->Nodes()) { + if (node->IsVar() && !node->IsCtrlVar() && node->Var()) { + var_infos.emplace_back(); + var_infos.back().name_ = node->Var()->Name(); + var_infos.back().type_ = node->Var()->GetType(); + var_infos.back().persistable_ = node->Var()->Persistable(); } } // If the loss_var_name is given, the number of graph should be only one. if (loss_var_name.size()) { - size_t graph_num = ir::GraphNum(*graphs[0]); + size_t graph_num = ir::GraphNum(*graph); if (graph_num > 1) { LOG(WARNING) << "The number of graph should be only one, " "but the current graph has " - << ir::GraphNum(*graphs[0]) + << ir::GraphNum(*graph) << " sub_graphs. If you want to see the nodes of the " "sub_graphs, you should use 'FLAGS_print_sub_graph_dir' " "to specify the output dir. NOTES: if you not do training, " @@ -340,7 +328,7 @@ ParallelExecutor::ParallelExecutor( VLOG(3) << "use AsyncSSAGraphExecutor"; member_->executor_.reset(new details::AsyncSSAGraphExecutor( exec_strategy, member_->local_scopes_, member_->places_, - std::move(graphs))); + std::move(graph))); } else if (build_strategy.enable_parallel_graph_) { VLOG(3) << "use ParallelSSAGraphExecutor"; #ifdef PADDLE_WITH_CUDA @@ -358,12 +346,12 @@ ParallelExecutor::ParallelExecutor( VLOG(3) << "use ThreadedSSAGraphExecutor"; member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, member_->places_, - std::move(graphs[0]))); + std::move(graph))); } else { VLOG(3) << "use FastThreadedSSAGraphExecutor"; member_->executor_.reset(new details::FastThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, member_->places_, - std::move(graphs[0]))); + std::move(graph))); } }