diff --git a/paddle/fluid/framework/details/all_reduce_deps_pass.h b/paddle/fluid/framework/details/all_reduce_deps_pass.h
index 1637c7a7a65a73556b2d546dc382985d4888386d..e8b91089816c71bc56ba7dba0105e85d73eb52ad 100644
--- a/paddle/fluid/framework/details/all_reduce_deps_pass.h
+++ b/paddle/fluid/framework/details/all_reduce_deps_pass.h
@@ -21,8 +21,6 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-constexpr char kAllOpDescs[] = "all_op_descs";
-
 // TODO(gongwb): overlap allreduce with backward computation.
 class AllReduceDepsPass : public ir::Pass {
  protected:
diff --git a/paddle/fluid/framework/details/memory_optimize_helper.h b/paddle/fluid/framework/details/memory_optimize_helper.h
index 0bfaf827fea84030de48a9984197f5b39f5c9261..2c9a16d445564ed7fd21d07e9dbb346bce7db590 100644
--- a/paddle/fluid/framework/details/memory_optimize_helper.h
+++ b/paddle/fluid/framework/details/memory_optimize_helper.h
@@ -29,8 +29,6 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-constexpr char kAllOpDescs[] = "all_op_descs";
-
 std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
 
 // NOTE(dzh): An ordered set for node reuse in memory optimize.
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index 4f856c6d9eb842add9eb5e6fe30639dc4170358d..27bc7718147da54f8bd09600cbf9692d3839ffdd 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -221,7 +221,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
    * Only variables should be the leaves of graph.
    */
   AddOutputToLeafOps(&result);
-  // result.Erase(kGraphOps);
   return graph;
 }
 
diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
index 3433c3424e43e6cf714fca995e6ccd5a9589a9c2..2cafa1873ad690260528aaff0c8f6684384d31b0 100644
--- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
@@ -19,12 +19,12 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
-    const std::vector<platform::Place> &places,
-    std::unique_ptr<ir::Graph> graph) {
+std::vector<std::unique_ptr<ir::Graph>>
+ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(
+    std::unique_ptr<ir::Graph> &&graph) {
   std::vector<std::unique_ptr<ir::Graph>> graphs;
-  graphs.reserve(places.size());
-  for (size_t i = 0; i < places.size(); ++i) {
+  graphs.reserve(places_.size());
+  for (size_t i = 0; i < places_.size(); ++i) {
     ProgramDesc empty;
     graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty)));
     auto &g = graphs.back();
@@ -60,7 +60,7 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
     }
   }
 
-  for (size_t dev_id = 0; dev_id < places.size(); ++dev_id) {
+  for (size_t dev_id = 0; dev_id < places_.size(); ++dev_id) {
     auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0];
     auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id];
     for (auto &name_pair : origin_vars) {
@@ -80,14 +80,26 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
 ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
-    std::vector<std::unique_ptr<ir::Graph>> &&graphs)
+    const framework::ProgramDesc &main_prog, std::unique_ptr<ir::Graph> &&graph)
     : strategy_(std::move(strategy)),
       local_scopes_(std::move(local_scopes)),
       pool_(places.size() >= 2 ?
                 new ::ThreadPool(places.size()) : nullptr),
       places_(std::move(places)),
-      graphs_(std::move(graphs)) {
+      main_prog_(main_prog),
+      // TODO(Yancey1989): Copying graphs is not safe since it deletes the attrs.
+      graphs_(SeparateMultiDevicesGraph(std::move(graph))) {
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
+  auto seq_allreduce_pass =
+      ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
+  seq_allreduce_pass->Erase(details::kAllOpDescs);
+  seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
+      details::kAllOpDescs,
+      new std::vector<OpDesc *>(main_prog_.Block(0).AllOps()));
+  for (size_t i = 0; i < graphs_.size(); ++i) {
+    graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
+  }
+
   // set the correct size of thread pool to each device.
   strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
                                ? 1UL
diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h
index c31bba17f6840019660376991145028b9c254933..f59305bf9827cfc6032eda355d884dec825dc351 100644
--- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h
@@ -28,16 +28,13 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
-    const std::vector<platform::Place> &places,
-    std::unique_ptr<ir::Graph> graph);
-
 class ParallelSSAGraphExecutor : public SSAGraphExecutor {
  public:
   ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
                            const std::vector<Scope *> &local_scopes,
                            const std::vector<platform::Place> &places,
-                           std::vector<std::unique_ptr<ir::Graph>> &&graphs);
+                           const framework::ProgramDesc &main_prog,
+                           std::unique_ptr<ir::Graph> &&graph);
   ~ParallelSSAGraphExecutor() final = default;
 
   const ir::Graph &Graph() const override { return *graphs_[0]; }
@@ -45,10 +42,14 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
   FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
 
  private:
+  std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
+      std::unique_ptr<ir::Graph> &&graph);
+
   ExecutionStrategy strategy_;
   std::vector<Scope *> local_scopes_;
   std::unique_ptr<::ThreadPool> pool_{nullptr};
   std::vector<platform::Place> places_;
+  framework::ProgramDesc main_prog_;
   std::vector<std::unique_ptr<ir::Graph>> graphs_;
 
   std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index b55a77451371b2dc3764eedea33126204b5d0997..d5b3782f622a8d5addc7eb51cebf7c8fbac3a453 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -26,6 +26,11 @@ limitations under the License. */
 
 namespace paddle {
 namespace framework {
+
+namespace details {
+constexpr char kAllOpDescs[] = "all_op_descs";
+}  // namespace details
+
 namespace ir {
 
 /*
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index dbe1bf9b2929726ea28b91dad7dd59616dfdbd21..56da5660095affa0ba49d8bc533d1da01ffd18be 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -305,21 +305,11 @@ ParallelExecutor::ParallelExecutor(
 
   if (build_strategy.enable_parallel_graph_) {
 #ifdef PADDLE_WITH_CUDA
-    auto parallel_graph =
-        details::SeparateMultiDevicesGraph(member_->places_, std::move(graph));
-    auto seq_allreduce_pass =
-        ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
-    seq_allreduce_pass->Erase(details::kAllOpDescs);
-    seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
-        details::kAllOpDescs,
-        new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
-    for (size_t i = 0; i < parallel_graph.size(); ++i) {
-      parallel_graph[i] =
-          seq_allreduce_pass->Apply(std::move(parallel_graph[i]));
-    }
+    // TODO(Yancey1989): Remove passing in the main_program when
+    // allreduce_seq_pass doesn't need it as the attr.
     member_->executor_.reset(new details::ParallelSSAGraphExecutor(
-        exec_strategy, member_->local_scopes_, member_->places_,
-        std::move(parallel_graph)));
+        exec_strategy, member_->local_scopes_, member_->places_, main_program,
+        std::move(graph)));
 #else
     PADDLE_THROW(
         "Paddle should be compiled with CUDA for ParallelGraph Execution.");
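
For context, a minimal sketch of the reworked call site (illustrative only, not part of the patch): the caller now hands the whole multi-device graph plus the main ProgramDesc to ParallelSSAGraphExecutor, which splits the graph per device and applies all_reduce_deps_pass internally. All identifiers below come from the hunks above; the surrounding ParallelExecutor setup is assumed and elided.

// Sketch only: mirrors the new parallel_executor.cc call site above, with
// exec_strategy, member_, main_program, and graph prepared earlier in
// ParallelExecutor::ParallelExecutor.
member_->executor_.reset(new details::ParallelSSAGraphExecutor(
    exec_strategy,           // num_threads_ is rescaled per device inside
    member_->local_scopes_,  // one scope per device
    member_->places_,        // one place per device
    main_program,            // used only to set details::kAllOpDescs on
                             // all_reduce_deps_pass in the constructor
    std::move(graph)));      // split via SeparateMultiDevicesGraph internally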