From 9605fcd124ae6a3cdad171d2d61107e9cabe4c2f Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Thu, 12 Jul 2018 11:26:08 +0800 Subject: [PATCH] all graphs --- paddle/fluid/framework/details/ssa_graph_checker.h | 1 - paddle/fluid/framework/details/ssa_graph_printer.h | 2 +- .../details/threaded_ssa_graph_executor.cc | 13 +++++++------ .../framework/details/threaded_ssa_graph_executor.h | 5 +++-- paddle/fluid/framework/parallel_executor.cc | 8 +------- 5 files changed, 12 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 7078b778bea..2c8b2e13c5c 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -21,7 +21,6 @@ namespace paddle { namespace framework { namespace details { -struct SSAGraph; class SSAGraghBuilderWithChecker : public SSAGraphBuilder { public: diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index 0bd2b10eda4..35f2a1b4f0e 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -21,7 +21,7 @@ namespace paddle { namespace framework { namespace details { -struct SSAGraph; + class SSAGraphPrinter { public: virtual ~SSAGraphPrinter() {} diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 07097c7e75c..ed8e38039ed 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -14,13 +14,14 @@ #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" +#include "paddle/fluid/framework/details/ssa_graph_builder.h" + namespace paddle { namespace framework { namespace details { ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( const ExecutionStrategy &strategy, const std::vector &local_scopes, - const std::vector &places, - std::unique_ptr &&graph) + const std::vector &places, std::unique_ptr &&graph) : graph_(std::move(graph)), pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_) : nullptr), @@ -43,18 +44,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( std::unordered_set delayed_ops; // Transform SSAGraph to pending_ops & pending_vars - for (auto &var_map : graph_->vars_) { + for (auto &var_map : graph_->Get("vars")) { for (auto &name_pair : var_map) { for (auto &version_pair : name_pair.second) { InsertPendingVar(&pending_vars, &ready_vars, version_pair.get()); } } } - for (auto &var : graph_->dep_vars_) { + for (auto &var : graph_->Get("dep_vars")) { InsertPendingVar(&pending_vars, &ready_vars, var.get()); } - for (auto &op : graph_->ops_) { + for (auto &op : graph_->Get("ops")) { if (op->Inputs().empty()) { // Special case, Op has no input. ready_ops.insert(op.get()); } else { @@ -158,7 +159,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( std::unordered_map> fetched_vars; for (auto &fetch_var_name : fetch_tensors) { - for (auto &var_map : graph_->vars_) { + for (auto &var_map : graph_->Get("vars")) { auto it = var_map.find(fetch_var_name); if (it != var_map.end()) { fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 09973b7a728..7d0aaf2ddc0 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -27,6 +27,7 @@ #include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h" +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { @@ -39,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, const std::vector &local_scopes, const std::vector &places, - std::unique_ptr &&graph); + std::unique_ptr &&graph); // Run a SSAGraph by a thread pool // Use topological sort algorithm @@ -52,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { details::OpHandleBase *op); private: - std::unique_ptr graph_; + std::unique_ptr graph_; std::unique_ptr<::ThreadPool> pool_; std::vector local_scopes_; std::vector places_; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 42bbd2b3ff4..d30aba07a01 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -135,14 +135,8 @@ ParallelExecutor::ParallelExecutor( builder_ = builder_factory.Create(); std::unique_ptr graph = builder_->Build(ProgramToGraph(main_program)); - std::unique_ptr ssa_graph(new details::SSAGraph); - ssa_graph->vars_ = std::move(graph->Get("vars")); - ssa_graph->ops_ = std::move(graph->Get("ops")); - ssa_graph->dep_vars_ = - std::move(graph->Get("dep_vars")); - member_->executor_.reset(new details::ThreadedSSAGraphExecutor( - exec_strategy, member_->local_scopes_, places, std::move(ssa_graph))); + exec_strategy, member_->local_scopes_, places, std::move(graph))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), -- GitLab