diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 7078b778bea9c181a69431f079ec1fecba788a42..2c8b2e13c5cc40ef799e2647fc50370363dcfe37 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -21,7 +21,6 @@ namespace paddle { namespace framework { namespace details { -struct SSAGraph; class SSAGraghBuilderWithChecker : public SSAGraphBuilder { public: diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index 0bd2b10eda45b01b45cca41de8eb4dd84f2828e9..35f2a1b4f0e3c478717007ea1a43e1e3d6820861 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -21,7 +21,7 @@ namespace paddle { namespace framework { namespace details { -struct SSAGraph; + class SSAGraphPrinter { public: virtual ~SSAGraphPrinter() {} diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 07097c7e75c6ce638549716cd6523f387cdefd92..ed8e38039edcf782cd5be97bee8197983719a7d8 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -14,13 +14,14 @@ #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" +#include "paddle/fluid/framework/details/ssa_graph_builder.h" + namespace paddle { namespace framework { namespace details { ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( const ExecutionStrategy &strategy, const std::vector &local_scopes, - const std::vector &places, - std::unique_ptr &&graph) + const std::vector &places, std::unique_ptr &&graph) : graph_(std::move(graph)), pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_) : nullptr), @@ -43,18 +44,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( std::unordered_set delayed_ops; // Transform SSAGraph to pending_ops & pending_vars - for (auto &var_map : graph_->vars_) { + for (auto &var_map : graph_->Get("vars")) { for (auto &name_pair : var_map) { for (auto &version_pair : name_pair.second) { InsertPendingVar(&pending_vars, &ready_vars, version_pair.get()); } } } - for (auto &var : graph_->dep_vars_) { + for (auto &var : graph_->Get("dep_vars")) { InsertPendingVar(&pending_vars, &ready_vars, var.get()); } - for (auto &op : graph_->ops_) { + for (auto &op : graph_->Get("ops")) { if (op->Inputs().empty()) { // Special case, Op has no input. ready_ops.insert(op.get()); } else { @@ -158,7 +159,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( std::unordered_map> fetched_vars; for (auto &fetch_var_name : fetch_tensors) { - for (auto &var_map : graph_->vars_) { + for (auto &var_map : graph_->Get("vars")) { auto it = var_map.find(fetch_var_name); if (it != var_map.end()) { fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 09973b7a72881464ad9e7776d4aad3d2261a118d..7d0aaf2ddc06e5f255921d06085f90a01db1185c 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -27,6 +27,7 @@ #include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h" +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { @@ -39,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, const std::vector &local_scopes, const std::vector &places, - std::unique_ptr &&graph); + std::unique_ptr &&graph); // Run a SSAGraph by a thread pool // Use topological sort algorithm @@ -52,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { details::OpHandleBase *op); private: - std::unique_ptr graph_; + std::unique_ptr graph_; std::unique_ptr<::ThreadPool> pool_; std::vector local_scopes_; std::vector places_; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 42bbd2b3ff4f67cc5cb15b5af9597cb2ad701ca8..d30aba07a018841636a9e6e9ae747ca1ef5df6fe 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -135,14 +135,8 @@ ParallelExecutor::ParallelExecutor( builder_ = builder_factory.Create(); std::unique_ptr graph = builder_->Build(ProgramToGraph(main_program)); - std::unique_ptr ssa_graph(new details::SSAGraph); - ssa_graph->vars_ = std::move(graph->Get("vars")); - ssa_graph->ops_ = std::move(graph->Get("ops")); - ssa_graph->dep_vars_ = - std::move(graph->Get("dep_vars")); - member_->executor_.reset(new details::ThreadedSSAGraphExecutor( - exec_strategy, member_->local_scopes_, places, std::move(ssa_graph))); + exec_strategy, member_->local_scopes_, places, std::move(graph))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos),