diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 3fab6adf0f87a016c917e271776de0f89ab2c101..b27647a8eebcf43b4331f946662af59596253ff0 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -37,8 +37,9 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
   }
 }
 
-void MultiDevSSAGraphBuilder::Build(const ProgramDesc &program,
-                                    SSAGraph *graph) const {
+std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
+    const ProgramDesc &program) const {
+  auto graph = new SSAGraph();
   SSAGraph &result = *graph;
   result.vars_.resize(places_.size());
 
@@ -134,6 +135,8 @@ void MultiDevSSAGraphBuilder::Build(const ProgramDesc &program,
      harzaeds need to be handled.
    */
   PolishGraphToSupportDataHazards(&result);
+
+  return std::unique_ptr<SSAGraph>(graph);
 }
 }  // namespace details
 }  // namespace framework
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index 510f85bc877da623cb7849c24a9bdf0e58785fe5..17959a94d6cf7762982778bac35da624d1cb7436 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -32,7 +32,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::vector<Scope *> &local_scopes,
                           platform::NCCLContextMap *nccl_ctxs);
 
-  void Build(const ProgramDesc &program, SSAGraph *graph) const override;
+  std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
 
  private:
   std::string loss_var_name_;
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h
index 848b90293a3a43dc747b8fad3f52ad43d69c1014..df05bb739421672fd8b169b45246b9a75f77ca7f 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.h
+++ b/paddle/fluid/framework/details/ssa_graph_builder.h
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/platform/place.h"
 
+#include <memory>
 #include <string>
 
 namespace paddle {
@@ -28,7 +29,7 @@ class SSAGraphBuilder {
  public:
   SSAGraphBuilder() {}
   virtual ~SSAGraphBuilder() {}
-  virtual void Build(const ProgramDesc &program, SSAGraph *graph) const = 0;
+  virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
 
   DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
 
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 78ef66be5141b33eaf0dbda0953b1b34a041e222..88070a06a255733198ebada586cc84d78cafe626 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -34,16 +34,16 @@ class SSAGraphExecutor {
   DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
 
  public:
-  explicit SSAGraphExecutor(SSAGraph *graph) : graph_(*graph) {}
+  // Steal graph inside
+  explicit SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph)
+      : graph_(std::move(graph)) {}
 
   virtual ~SSAGraphExecutor() {}
 
-  virtual void Run(Scope *global_scope,
-                   const std::vector<std::string> &fetch_tensors,
-                   const std::string &fetch_list_name) = 0;
+  virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
 
  protected:
-  SSAGraph &graph_;
+  std::unique_ptr<SSAGraph> graph_;
 };
 
 class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
@@ -51,16 +51,17 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
                            const std::vector<Scope *> &local_scopes,
                            const std::vector<platform::Place> &places,
-                           SSAGraph *graph)
-      : SSAGraphExecutor(graph),
+                           std::unique_ptr<SSAGraph> &&graph)
+      : SSAGraphExecutor(std::move(graph)),
         pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
         local_scopes_(local_scopes),
         places_(places),
         fetch_ctxs_(places),
         use_event_(use_event) {}
 
-  void Run(Scope *global_scope, const std::vector<std::string> &fetch_tensors,
-           const std::string &fetch_list_name) override {
+  // Run a SSAGraph by a thread pool
+  // Use topological sort algorithm
+  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override {
     std::unordered_map<OpHandleBase *, size_t> pending_ops;
     std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
     std::unordered_set<OpHandleBase *> ready_ops;
@@ -74,18 +75,18 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
     };
 
     // Transform SSAGraph to pending_ops & pending_vars
-    for (auto &var_map : graph_.vars_) {
+    for (auto &var_map : graph_->vars_) {
       for (auto &name_pair : var_map) {
         for (auto &version_pair : name_pair.second) {
           InsertPendingVar(version_pair.second);
         }
       }
     }
-    for (auto &var : graph_.dep_vars_) {
+    for (auto &var : graph_->dep_vars_) {
       InsertPendingVar(*var);
     }
 
-    for (auto &op : graph_.ops_) {
+    for (auto &op : graph_->ops_) {
       if (op->inputs_.empty()) {  // Special case, Op has no input.
         ready_ops.insert(op.get());
       } else {
@@ -101,7 +102,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
     std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
 
     for (auto &fetch_var_name : fetch_tensors) {
-      for (auto &var_map : graph_.vars_) {
+      for (auto &var_map : graph_->vars_) {
         auto it = var_map.find(fetch_var_name);
         if (it != var_map.end()) {
           fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
@@ -182,8 +183,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
       fetch_op.WaitAndMergeCPUTensors();
     }
 
-    *global_scope->Var(fetch_list_name)->GetMutable<FeedFetchList>() =
-        fetch_data;
+    return fetch_data;
   }
 
   ~ThreadedSSAGraphExecutor() {}
@@ -240,8 +240,6 @@ class ParallelExecutorPrivate {
 
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
 
-  details::SSAGraph graph_;
-
   std::unique_ptr<SSAGraphExecutor> executor_;
 };
 
@@ -274,10 +272,10 @@ ParallelExecutor::ParallelExecutor(
   details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
                                            params, member_->local_scopes_,
                                            member_->nccl_ctxs_.get());
-  builder.Build(main_program, &member_->graph_);
+  auto graph = builder.Build(main_program);
 
   member_->executor_.reset(new ThreadedSSAGraphExecutor(
-      num_threads, true, member_->local_scopes_, places, &member_->graph_));
+      num_threads, true, member_->local_scopes_, places, std::move(graph)));
 
   // Step 3. Create vars in each scope;
   for (auto *scope : member_->local_scopes_) {
@@ -338,8 +336,9 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
 
 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
-  member_->executor_->Run(member_->global_scope_, fetch_tensors,
-                          fetched_var_name);
+  auto fetch_data = member_->executor_->Run(fetch_tensors);
+  *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
+      fetch_data;
 }
 
 }  // namespace framework
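
Note: below is a minimal, standalone C++ sketch of the ownership-transfer pattern this patch introduces, intended only as an illustration. It is not Paddle code: Graph, GraphBuilder, and GraphExecutor are simplified stand-ins for SSAGraph, SSAGraphBuilder, and SSAGraphExecutor, and the op names are placeholders. The point it shows is the same as in the diff: Build() allocates the graph and returns a std::unique_ptr, the executor takes that pointer by rvalue reference and stores it, and Run() returns its result directly instead of writing into a caller-provided scope, so the graph no longer needs to live as a member of ParallelExecutorPrivate.

// Minimal sketch (assumed names, not the real Paddle types).
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Graph {
  std::vector<std::string> ops;  // stand-in for the SSA graph contents
};

class GraphBuilder {
 public:
  // The builder owns the allocation only until it hands it to the caller.
  std::unique_ptr<Graph> Build() const {
    auto graph = std::unique_ptr<Graph>(new Graph());
    graph->ops = {"mul", "elementwise_add", "mean"};
    return graph;
  }
};

class GraphExecutor {
 public:
  // Steal the graph: after construction the executor is its sole owner.
  explicit GraphExecutor(std::unique_ptr<Graph> &&graph)
      : graph_(std::move(graph)) {}

  // Return the result directly instead of filling a caller-provided scope.
  std::vector<std::string> Run() const { return graph_->ops; }

 private:
  std::unique_ptr<Graph> graph_;
};

int main() {
  GraphBuilder builder;
  // Ownership moves straight from Build() into the executor, mirroring how
  // builder.Build(main_program) feeds std::move(graph) into
  // ThreadedSSAGraphExecutor in the diff above.
  GraphExecutor executor(builder.Build());
  for (const auto &op : executor.Run()) {
    std::cout << op << "\n";
  }
  return 0;
}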